diff --git a/README.md b/README.md index 7785ac548..a1a49e57d 100644 --- a/README.md +++ b/README.md @@ -1,39 +1,104 @@ -![Build Status](https://azure-iot-sdks.visualstudio.com/azure-iot-sdks/_apis/build/status/python/python-preview) +# +
+ +

V2 - We are now GA!

+
-# Azure IoT Hub Python SDKs v2 - PREVIEW +![Build Status](https://azure-iot-sdks.visualstudio.com/azure-iot-sdks/_apis/build/status/Azure.azure-iot-sdk-python) -This repository contains the code for the future v2.0.0 of the Azure IoT SDKs for Python. The goal of v2.0.0 is to be a complete rewrite of the existing SDK that maximizes the use of the Python language and its standard features rather than wrap over the C SDK, like v1.x.x of the SDK did. +This repository contains code for the Azure IoT SDKs for Python. This enables python developers to easily create IoT device solutions that semealessly +connection to the Azure IoTHub ecosystem. *If you're looking for the v1.x.x client library, it is now preserved in the [v1-deprecated](https://github.com/Azure/azure-iot-sdk-python/tree/v1-deprecated) branch.* -**Note that these SDKs are currently in preview, and are subject to change.** -# SDKs +## Azure IoT SDK for Python -This repository contains the following SDKs: +This repository contains the following libraries: -* [Azure IoT Device SDK](azure-iot-device) - /azure-iot-device - * Provision a device using the Device Provisioning Service for use with the Azure IoT hub - * Send/receive telemetry between a device or module and the Azure IoT hub or Azure IoT Edge device - * Handle direct methods invoked by the Azure IoT hub on a device - * Handle twin events and report twin updates - * *Still in development* - - *Blob/File upload* - - *Invoking method from a module client onto a leaf device* +* [Azure IoT Device library](https://github.com/Azure/azure-iot-sdk-python/blob/master/azure-iot-device/README.md) -* Azure IoT Hub SDK **(COMING SOON)** - * Do service/management operations on the Azure IoT Hub +* [Azure IoT Hub Service library](https://github.com/Azure/azure-iot-sdk-python/blob/master/azure-iot-hub/README.md) -* Azure IoT Hub Provisioning SDK **(COMING SOON)** - * Do service/management operations on the Azure IoT Device Provisioning Service +* Coming Soon: Azure IoT Device Provisioning Service Library -# How to install the SDKs +## Installing the libraries -``` -pip install azure-iot-device -``` +Pip installs are provided for all of the SDK libraries in this repo: -# Contributing +[Device libraries](https://github.com/Azure/azure-iot-sdk-python/tree/master/azure-iot-device#installation) + +[IoTHub library](https://github.com/Azure/azure-iot-sdk-python/blob/master/azure-iot-hub/README.md#installation) + +## Features + +:heavy_check_mark: feature available :heavy_multiplication_x: feature planned but not yet supported :heavy_minus_sign: no support planned* + +*Features that are not planned may be prioritized in a future release, but are not currently planned + +### Device Client Library ([azure-iot-device](https://github.com/Azure/azure-iot-sdk-python/tree/master/azure-iot-device)) + +#### IoTHub Device Client + +| Features | Status | Description | +|------------------------------------------------------------------------------------------------------------------|----------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [Authentication](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-security-deployment) | :heavy_check_mark: | Connect your device to IoT Hub securely with supported authentication, including private key, SASToken, X-509 Self Signed and Certificate Authority (CA) Signed. | +| [Send device-to-cloud message](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-messages-d2c) | :heavy_check_mark: | Send device-to-cloud messages (max 256KB) to IoT Hub with the option to add custom properties. | +| [Receive cloud-to-device messages](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-messages-c2d) | :heavy_check_mark: | Receive cloud-to-device messages and read associated custom and system properties from IoT Hub, with the option to complete/reject/abandon C2D messages. | +| [Device Twins](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-device-twins) | :heavy_check_mark: | IoT Hub persists a device twin for each device that you connect to IoT Hub. The device can perform operations like get twin tags, subscribe to desired properties. | +| [Direct Methods](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-direct-methods) | :heavy_check_mark: | IoT Hub gives you the ability to invoke direct methods on devices from the cloud. The SDK supports handler for method specific and generic operation. | +| [Connection Status and Error reporting](https://docs.microsoft.com/en-us/rest/api/iothub/common-error-codes) | :heavy_multiplication_x: | Error reporting for IoT Hub supported error code. *This SDK supports error reporting on authentication and Device Not Found. | +| Retry policies | :heavy_check_mark: | Retry policy for unsuccessful device-to-cloud messages. | +| [Upload file to Blob](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-file-upload) | :heavy_check_mark: | A device can initiate a file upload and notifies IoT Hub when the upload is complete. | + +#### IoTHub Module Client + +| Features | Status | Description | +|------------------------------------------------------------------------------------------------------------------|----------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [Authentication](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-security-deployment) | :heavy_check_mark: | Connect your device to IoT Hub securely with supported authentication, including private key, SASToken, X-509 Self Signed and Certificate Authority (CA) Signed. | +| [Send device-to-cloud message](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-messages-d2c) | :heavy_check_mark: | Send device-to-cloud messages (max 256KB) to IoT Hub with the option to add custom properties. | +| [Receive cloud-to-device messages](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-messages-c2d) | :heavy_check_mark: | Receive cloud-to-device messages and read associated custom and system properties from IoT Hub, with the option to complete/reject/abandon C2D messages. | +| [Device Twins](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-device-twins) | :heavy_check_mark: | IoT Hub persists a device twin for each device that you connect to IoT Hub. The device can perform operations like get twin tags, subscribe to desired properties. | +| [Direct Methods](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-direct-methods) | :heavy_check_mark: | IoT Hub gives you the ability to invoke direct methods on devices from the cloud. The SDK supports handler for method specific and generic operation. | +| [Connection Status and Error reporting](https://docs.microsoft.com/en-us/rest/api/iothub/common-error-codes) | :heavy_multiplication_x: | Error reporting for IoT Hub supported error code. *This SDK supports error reporting on authentication and Device Not Found. | +| Retry policies | :heavy_check_mark: | Retry policy for connecting disconnected devices and resubmitting messages. | +| Direct Invocation of Method on Modules | :heavy_check_mark: | Invoke method calls to another module using using the Edge Gateway. | + +#### Provisioning Device Client + +| Features | Status | Description | +|-----------------------------|--------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| TPM Individual Enrollment | :heavy_minus_sign: | Provisioning via [Trusted Platform Module](https://docs.microsoft.com/en-us/azure/iot-dps/concepts-security#trusted-platform-module-tpm). | +| X.509 Individual Enrollment | :heavy_check_mark: | Provisioning via [X.509 root certificate](https://docs.microsoft.com/en-us/azure/iot-dps/concepts-security#root-certificate). Please review the [samples](./azure-iot-device/samples/async-hub-scenarios/provision_x509_and_send_telemetry.py) folder and this [quickstart](https://docs.microsoft.com/en-us/azure/iot-dps/quick-create-simulated-device-x509-python) on how to create a device client. | +| X.509 Enrollment Group | :heavy_check_mark: | Provisioning via [X.509 leaf certificate](https://docs.microsoft.com/en-us/azure/iot-dps/concepts-security#leaf-certificate)). Please review the [samples](./azure-iot-device/samples/async-hub-scenarios/provision_x509_and_send_telemetry.py) folder on how to create a device client. | +| Symmetric Key Enrollment | :heavy_check_mark: | Provisioning via [Symmetric key attestation](https://docs.microsoft.com/en-us/azure/iot-dps/concepts-symmetric-key-attestation)). Please review the [samples](./azure-iot-device/samples/async-hub-scenarios/provision_symmetric_key_and_send_telemetry.py) folder on how to create a device client. | + +### IoTHub Service Library ([azure-iot-hub](https://github.com/Azure/azure-iot-sdk-python/blob/master/azure-iot-hub/azure/iot/hub/iothub_registry_manager.py)) + +#### Registry Manager + +| Features | Status | Description | +|---------------------------------------------------------------------------------------------------------------|--------------------------|------------------------------------------------------------------------------------------------------------------------------------| +| [Identity registry (CRUD)](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-identity-registry) | :heavy_check_mark: | Use your backend app to perform CRUD operation for individual device or in bulk. | +| [Cloud-to-device messaging](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-messages-c2d) | :heavy_multiplication_x: | Use your backend app to send cloud-to-device messages, and set up cloud-to-device message receivers. | +| [Direct Methods operations](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-direct-methods) | :heavy_check_mark: | Use your backend app to invoke direct method on device. | +| [Device Twins operations](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-device-twins) | :heavy_check_mark: | Use your backend app to perform device twin operations. *Twin reported property update callback and replace twin are in progress. | +| [Query](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-query-language) | :heavy_multiplication_x: | Use your backend app to perform query for information. | +| [Jobs](https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-jobs) | :heavy_multiplication_x: | Use your backend app to perform job operation. | + +### IoTHub Provisioning Service Library + +Feature is Coming Soon + +| Features | Status | Description | +|-----------------------------------------------------|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------| +| CRUD Operation with TPM Individual Enrollment | :heavy_multiplication_x: | Manage device enrollment using TPM with the service SDK. | +| Bulk CRUD Operation with TPM Individual Enrollment | :heavy_multiplication_x: | Bulk manage device enrollment using TPM with the service SDK. | +| CRUD Operation with X.509 Individual Enrollment | :heavy_multiplication_x: | Manages device enrollment using X.509 individual enrollment with the service SDK. | +| CRUD Operation with X.509 Group Enrollment | :heavy_multiplication_x: | Manages device enrollment using X.509 group enrollment with the service SDK. | +| Query enrollments | :heavy_multiplication_x: | Query registration states with the service SDK. | + +## Contributing This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us @@ -46,3 +111,4 @@ provided by the bot. You will only need to do this once across all repos using o This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. + diff --git a/SECURITY.MD b/SECURITY.MD new file mode 100644 index 000000000..8c35a2dff --- /dev/null +++ b/SECURITY.MD @@ -0,0 +1,41 @@ + + +# Security + +Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). + +If you believe you have found a security vulnerability in any Microsoft-owned repository that meets Microsoft's [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)) of a security vulnerability, please report it to us as described below. + +## Reporting Security Issues + +**Please do not report security vulnerabilities through public GitHub issues.** + +Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). + +If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). + +You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). + +Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: + +* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) +* Full paths of source file(s) related to the manifestation of the issue +* The location of the affected source code (tag/branch/commit or direct URL) +* Any special configuration required to reproduce the issue +* Step-by-step instructions to reproduce the issue +* Proof-of-concept or exploit code (if possible) +* Impact of the issue, including how an attacker might exploit the issue + +This information will help us triage your report more quickly. + +If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. + +## Preferred Languages + +We prefer all communications to be in English. + +## Policy + +Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). + + diff --git a/azure-iot-device/.bumpverion.cfg b/azure-iot-device/.bumpverion.cfg index 7cb81dba4..38d73fe88 100644 --- a/azure-iot-device/.bumpverion.cfg +++ b/azure-iot-device/.bumpverion.cfg @@ -1,7 +1,7 @@ [bumpversion] -current_version = 2.0.0-preview.10 -parse = (?P\d+)\.(?P\d+)\.(?P\d+)-preview\.(?P\d+) -serialize = {major}.{minor}.{patch}-preview.{preview} +current_version = 2.1.0 +parse = (?P\d+)\.(?P\d+)\.(?P\d+) +serialize = {major}.{minor}.{patch} -[bumpversion:part:preview] +[bumpversion:part] diff --git a/azure-iot-device/README.md b/azure-iot-device/README.md index d59c087cf..32508068e 100644 --- a/azure-iot-device/README.md +++ b/azure-iot-device/README.md @@ -1,139 +1,56 @@ # Azure IoT Device SDK + The Azure IoT Device SDK for Python provides functionality for communicating with the Azure IoT Hub for both Devices and Modules. -**Note that this SDK is currently in preview, and is subject to change.** +## Azure IoT Device Features -## Features The SDK provides the following clients: * ### Provisioning Device Client - * Creates a device identity on the Azure IoT Hub + + * Creates a device identity on the Azure IoT Hub * ### IoT Hub Device Client - * Send telemetry messages to Azure IoT Hub - * Receive Cloud-to-Device (C2D) messages from the Azure IoT Hub - * Receive and respond to direct method invocations from the Azure IoT Hub + + * Send telemetry messages to Azure IoT Hub + * Receive Cloud-to-Device (C2D) messages from the Azure IoT Hub + * Receive and respond to direct method invocations from the Azure IoT Hub * ### IoT Hub Module Client - * Supports Azure IoT Edge Hub and Azure IoT Hub - * Send telemetry messages to a Hub or to another Module - * Receive Input messages from a Hub or other Modules - * Receive and respond to direct method invocations from a Hub or other Modules + + * Supports Azure IoT Edge Hub and Azure IoT Hub + * Send telemetry messages to a Hub or to another Module + * Receive Input messages from a Hub or other Modules + * Receive and respond to direct method invocations from a Hub or other Modules These clients are available with an asynchronous API, as well as a blocking synchronous API for compatibility scenarios. **We recommend you use Python 3.7+ and the asynchronous API.** | Python Version | Asynchronous API | Synchronous API | | -------------- | ---------------- | --------------- | | Python 3.5.3+ | **YES** | **YES** | -| Python 3.4 | NO | **YES** | | Python 2.7 | NO | **YES** | ## Installation -``` + +```Shell pip install azure-iot-device ``` -## Set up an IoT Hub and create a Device Identity -1. Install the [Azure CLI](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli?view=azure-cli-latest) (or use the [Azure Cloud Shell](https://shell.azure.com/)) and use it to [create an Azure IoT Hub](https://docs.microsoft.com/en-us/cli/azure/iot/hub?view=azure-cli-latest#az-iot-hub-create). +## Device Samples - ```bash - az iot hub create --resource-group --name - ``` - * Note that this operation make take a few minutes. +Check out the [samples repository](./azure-iot-device/samples) for example code showing how the SDK can be used in a variety of scenarios, including: -2. Add the IoT Extension to the Azure CLI, and then [register a device identity](https://docs.microsoft.com/en-us/cli/azure/ext/azure-cli-iot-ext/iot/hub/device-identity?view=azure-cli-latest#ext-azure-cli-iot-ext-az-iot-hub-device-identity-create) - - ```bash - az extension add --name azure-cli-iot-ext - az iot hub device-identity create --hub-name --device-id - ``` - -2. [Retrieve your Device Connection String](https://docs.microsoft.com/en-us/cli/azure/ext/azure-cli-iot-ext/iot/hub/device-identity?view=azure-cli-latest#ext-azure-cli-iot-ext-az-iot-hub-device-identity-show-connection-string) using the Azure CLI - - ```bash - az iot hub device-identity show-connection-string --device-id --hub-name - ``` - - It should be in the format: - ``` - HostName=.azure-devices.net;DeviceId=;SharedAccessKey= - ``` - -## Send a simple telemetry message - -1. [Begin monitoring for telemetry](https://docs.microsoft.com/en-us/cli/azure/ext/azure-cli-iot-ext/iot/hub?view=azure-cli-latest#ext-azure-cli-iot-ext-az-iot-hub-monitor-events) on your IoT Hub using the Azure CLI - - ```bash - az iot hub monitor-events --hub-name --output table - ``` - -2. On your device, set the Device Connection String as an enviornment variable called `IOTHUB_DEVICE_CONNECTION_STRING`. - - ### Windows - ```cmd - set IOTHUB_DEVICE_CONNECTION_STRING= - ``` - * Note that there are **NO** quotation marks around the connection string. - - ### Linux - ```bash - export IOTHUB_DEVICE_CONNECTION_STRING="" - ``` - -3. Copy the following code that sends a single message to the IoT Hub into a new python file on your device, and run it from the terminal or IDE (**requires Python 3.7+**): - - ```python - import asyncio - import os - from azure.iot.device.aio import IoTHubDeviceClient - - - async def main(): - # Fetch the connection string from an enviornment variable - conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") - - # Create instance of the device client using the connection string - device_client = IoTHubDeviceClient.create_from_connection_string(conn_str) - - # Send a single message - print("Sending message...") - await device_client.send_message("This is a message that is being sent") - print("Message successfully sent!") - - # finally, disconnect - await device_client.disconnect() - - - if __name__ == "__main__": - asyncio.run(main()) - ``` - -4. Check the Azure CLI output to verify that the message was received by the IoT Hub. You should see the following output: - - ```bash - Starting event monitor, use ctrl-c to stop... - event: - origin: - payload: This is a message that is being sent - ``` - -5. Your device is now able to connect to Azure IoT Hub! - -## Additional Samples -Check out the [samples repository](https://github.com/Azure/azure-iot-sdk-python-preview/tree/master/azure-iot-device/samples) for example code showing how the SDK can be used in a variety of scenarios, including: * Sending multiple telemetry messages at once. * Receiving Cloud-to-Device messages. * Using Edge Modules with the Azure IoT Edge Hub. * Send and receive updates to device twin * Receive invocations to direct methods * Register a device with the Device Provisioning Service -* Legacy scenarios for Python 2.7 and 3.4 ## Getting help and finding API docs Our SDK makes use of docstrings which means you cand find API documentation directly through Python with use of the [help](https://docs.python.org/3/library/functions.html#help) command: - ```python >>> from azure.iot.device import IoTHubDeviceClient >>> help(IoTHubDeviceClient) diff --git a/azure-iot-device/azure/iot/device/common/__init__.py b/azure-iot-device/azure/iot/device/common/__init__.py index c10dc4c6e..89a2ca3bb 100644 --- a/azure-iot-device/azure/iot/device/common/__init__.py +++ b/azure-iot-device/azure/iot/device/common/__init__.py @@ -5,6 +5,6 @@ This package provides shared modules for use with various Azure IoT device-side INTERNAL USAGE ONLY """ -from .models import X509 +from .models import X509, ProxyOptions -__all__ = ["X509"] +__all__ = ["X509", "ProxyOptions"] diff --git a/azure-iot-device/azure/iot/device/common/async_adapter.py b/azure-iot-device/azure/iot/device/common/async_adapter.py index f3cb69b5f..d4cf3fca4 100644 --- a/azure-iot-device/azure/iot/device/common/async_adapter.py +++ b/azure-iot-device/azure/iot/device/common/async_adapter.py @@ -7,6 +7,7 @@ import functools import logging +import traceback import azure.iot.device.common.asyncio_compat as asyncio_compat logger = logging.getLogger(__name__) @@ -69,9 +70,9 @@ class AwaitableCallback(object): result = None if exception: - logger.error( - "Callback completed with error {}".format(exception), exc_info=exception - ) + # Do not use exc_info parameter on logger.error. This casuses pytest to save the traceback which saves stack frames which shows up as a leak + logger.error("Callback completed with error {}".format(exception)) + logger.error(traceback.format_exception_only(type(exception), exception)) loop.call_soon_threadsafe(self.future.set_exception, exception) else: logger.debug("Callback completed with result {}".format(result)) diff --git a/azure-iot-device/azure/iot/device/common/callable_weak_method.py b/azure-iot-device/azure/iot/device/common/callable_weak_method.py new file mode 100644 index 000000000..dfb9ba2b3 --- /dev/null +++ b/azure-iot-device/azure/iot/device/common/callable_weak_method.py @@ -0,0 +1,78 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import weakref + + +class CallableWeakMethod(object): + """ + Object which makes a weak reference to a method call. Similar to weakref.WeakMethod, + but works on Python 2.7 and returns an object which is callable. + + This objet is used primarily for callbacks and it prevents circular references in the + garbage collector. It is used specifically in the scenario where object holds a + refernce to object b and b holds a callback into a (which creates a rererence + back into a) + + By default, method references are _strong_, and we end up with we have a situation + where a has a _strong) reference to b and b has a _strong_ reference to a. + + The Python 3.4+ garbage collectors handle this circular reference just fine, but the + 2.7 garbage collector fails, but only when one of the objects has a finalizer method. + + ''' + # example of bad (strong) circular dependency: + class A(object): + def --init__(self): + self.b = B() # A objects now have a strong refernce to B objects + b.handler = a.method() # and B object have a strong reference back into A objects + def method(self): + pass + ''' + + In the example above, if a or B has a finalizer, that object will be considered uncollectable + (on 2.7) and both objects will leak + + However, if we use this object, a will a _strong_ reference to b, and b will have a _weak_ + reference =back to a, and the circular depenency chain is broken. + + ``` + # example of better (weak) circular dependency: + class A(object): + def --init__(self): + self.b = B() # A objects now have a strong refernce to B objects + b.handler = CallableWeakMethod(a, "method") # and B objects have a WEAK reference back into A objects + def method(self): + pass + ``` + + In this example, there is no circular reference, and the Python 2.7 garbage collector is able + to collect both objects, even if one of them has a finalizer. + + When we reach the point where all supported interpreters implement PEP 442, we will + no longer need this object + + ref: https://www.python.org/dev/peps/pep-0442/ + """ + + def __init__(self, object, method_name): + self.object_weakref = weakref.ref(object) + self.method_name = method_name + + def _get_method(self): + return getattr(self.object_weakref(), self.method_name) + + def __call__(self, *args, **kwargs): + return self._get_method()(*args, **kwargs) + + def __eq__(self, other): + return self._get_method() == other + + def __repr__(self): + if self.object_weakref(): + return "CallableWeakMethod for {}".format(self._get_method()) + else: + return "CallableWeakMethod for {} (DEAD)".format(self.method_name) diff --git a/azure-iot-device/azure/iot/device/common/chainable_exception.py b/azure-iot-device/azure/iot/device/common/chainable_exception.py new file mode 100644 index 000000000..bad9523e3 --- /dev/null +++ b/azure-iot-device/azure/iot/device/common/chainable_exception.py @@ -0,0 +1,24 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + + +class ChainableException(Exception): + """This exception stores a reference to a previous exception which has caused + the current one""" + + def __init__(self, message=None, cause=None): + # By using .__cause__, this will allow typical stack trace behavior in Python 3, + # while still being able to operate in Python 2. + self.__cause__ = cause + super(ChainableException, self).__init__(message) + + def __str__(self): + if self.__cause__: + return "{} caused by {}".format( + super(ChainableException, self).__repr__(), self.__cause__.__repr__() + ) + else: + return super(ChainableException, self).__repr__() diff --git a/azure-iot-device/azure/iot/device/common/errors.py b/azure-iot-device/azure/iot/device/common/errors.py deleted file mode 100644 index 670e36f88..000000000 --- a/azure-iot-device/azure/iot/device/common/errors.py +++ /dev/null @@ -1,189 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - - -class OperationCancelledError(Exception): - """ - Operation was cancelled. - """ - - pass - - -class ConnectionFailedError(Exception): - """ - Connection failed to be established - """ - - pass - - -class ConnectionDroppedError(Exception): - """ - Previously established connection was dropped - """ - - pass - - -class ArgumentError(Exception): - """ - Service returned 400 - """ - - pass - - -class UnauthorizedError(Exception): - """ - Authorization failed or service returned 401 - """ - - pass - - -class QuotaExceededError(Exception): - """ - Service returned 403 - """ - - pass - - -class NotFoundError(Exception): - """ - Service returned 404 - """ - - pass - - -class DeviceTimeoutError(Exception): - """ - Service returned 408 - """ - - # TODO: is this a method call error? If so, do we retry? - pass - - -class DeviceAlreadyExistsError(Exception): - """ - Service returned 409 - """ - - pass - - -class InvalidEtagError(Exception): - """ - Service returned 412 - """ - - pass - - -class MessageTooLargeError(Exception): - """ - Service returned 413 - """ - - pass - - -class ThrottlingError(Exception): - """ - Service returned 429 - """ - - pass - - -class InternalServiceError(Exception): - """ - Service returned 500 - """ - - pass - - -class BadDeviceResponseError(Exception): - """ - Service returned 502 - """ - - # TODO: is this a method invoke thing? - pass - - -class ServiceUnavailableError(Exception): - """ - Service returned 503 - """ - - pass - - -class TimeoutError(Exception): - """ - Operation timed out or service returned 504 - """ - - pass - - -class FailedStatusCodeError(Exception): - """ - Service returned unknown status code - """ - - pass - - -class ProtocolClientError(Exception): - """ - Error returned from protocol client library - """ - - pass - - -class PipelineError(Exception): - """ - Error returned from transport pipeline - """ - - pass - - -status_code_to_error = { - 400: ArgumentError, - 401: UnauthorizedError, - 403: QuotaExceededError, - 404: NotFoundError, - 408: DeviceTimeoutError, - 409: DeviceAlreadyExistsError, - 412: InvalidEtagError, - 413: MessageTooLargeError, - 429: ThrottlingError, - 500: InternalServiceError, - 502: BadDeviceResponseError, - 503: ServiceUnavailableError, - 504: TimeoutError, -} - - -def error_from_status_code(status_code, message=None): - """ - Return an Error object from a failed status code - - :param int status_code: Status code returned from failed operation - :returns: Error object - """ - if status_code in status_code_to_error: - return status_code_to_error[status_code](message) - else: - return FailedStatusCodeError(message) diff --git a/azure-iot-device/azure/iot/device/common/evented_callback.py b/azure-iot-device/azure/iot/device/common/evented_callback.py index 4b6a6af62..283beb274 100644 --- a/azure-iot-device/azure/iot/device/common/evented_callback.py +++ b/azure-iot-device/azure/iot/device/common/evented_callback.py @@ -6,6 +6,7 @@ import threading import logging import six +import traceback logger = logging.getLogger(__name__) @@ -31,7 +32,6 @@ class EventedCallback(object): def wrapping_callback(*args, **kwargs): if "error" in kwargs and kwargs["error"]: - logger.error("Callback called with error {}".format(kwargs["error"])) self.exception = kwargs["error"] elif return_arg_name: if return_arg_name in kwargs: @@ -44,10 +44,9 @@ class EventedCallback(object): ) if self.exception: - logger.error( - "Callback completed with error {}".format(self.exception), - exc_info=self.exception, - ) + # Do not use exc_info parameter on logger.error. This casuses pytest to save the traceback which saves stack frames which shows up as a leak + logger.error("Callback completed with error {}".format(self.exception)) + logger.error(traceback.format_exc()) else: logger.debug("Callback completed with result {}".format(self.result)) diff --git a/azure-iot-device/azure/iot/device/common/unhandled_exceptions.py b/azure-iot-device/azure/iot/device/common/handle_exceptions.py similarity index 52% rename from azure-iot-device/azure/iot/device/common/unhandled_exceptions.py rename to azure-iot-device/azure/iot/device/common/handle_exceptions.py index 2fef92b4b..f554a7c7c 100644 --- a/azure-iot-device/azure/iot/device/common/unhandled_exceptions.py +++ b/azure-iot-device/azure/iot/device/common/handle_exceptions.py @@ -4,11 +4,12 @@ # license information. # -------------------------------------------------------------------------- import logging +import traceback logger = logging.getLogger(__name__) -def exception_caught_in_background_thread(e): +def handle_background_exception(e): """ Function which handled exceptions that are caught in background thread. This is typically called from the callback thread inside the pipeline. These exceptions @@ -24,4 +25,32 @@ def exception_caught_in_background_thread(e): # @FUTURE: We should add a mechanism which allows applications to receive these # exceptions so they can respond accordingly - logger.error(msg="Exception caught in background thread. Unable to handle.", exc_info=e) + logger.error(msg="Exception caught in background thread. Unable to handle.") + logger.error(traceback.format_exception_only(type(e), e)) + + +def swallow_unraised_exception(e, log_msg=None, log_lvl="warning"): + """Swallow and log an exception object. + + Convenience function for logging, as exceptions can only be logged correctly from within a + except block. + + :param Exception e: Exception object to be swallowed. + :param str log_msg: Optional message to use when logging. + :param str log_lvl: The log level to use for logging. Default "warning". + """ + try: + raise e + except Exception: + if log_lvl == "warning": + logger.warning(log_msg) + logger.warning(traceback.format_exc()) + elif log_lvl == "error": + logger.error(log_msg) + logger.error(traceback.format_exc()) + elif log_lvl == "info": + logger.info(log_msg) + logger.info(traceback.format_exc()) + else: + logger.debug(log_msg) + logger.debug(traceback.format_exc()) diff --git a/azure-iot-device/azure/iot/device/common/http_transport.py b/azure-iot-device/azure/iot/device/common/http_transport.py new file mode 100644 index 000000000..aa6d7a3a3 --- /dev/null +++ b/azure-iot-device/azure/iot/device/common/http_transport.py @@ -0,0 +1,108 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import logging +import uuid +import threading +import json +import ssl +from . import transport_exceptions as exceptions +from .pipeline import pipeline_thread +from six.moves import http_client + +logger = logging.getLogger(__name__) + + +class HTTPTransport(object): + """ + A wrapper class that provides an implementation-agnostic HTTP interface. + """ + + def __init__(self, hostname, server_verification_cert=None, x509_cert=None, cipher=None): + """ + Constructor to instantiate an HTTP protocol wrapper. + + :param str hostname: Hostname or IP address of the remote host. + :param str server_verification_cert: Certificate which can be used to validate a server-side TLS connection (optional). + :param x509_cert: Certificate which can be used to authenticate connection to a server in lieu of a password (optional). + """ + self._hostname = hostname + self._server_verification_cert = server_verification_cert + self._x509_cert = x509_cert + self._ssl_context = self._create_ssl_context() + + def _create_ssl_context(self): + """ + This method creates the SSLContext object used to authenticate the connection. The generated context is used by the http_client and is necessary when authenticating using a self-signed X509 cert or trusted X509 cert + """ + logger.debug("creating a SSL context") + ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLSv1_2) + + if self._server_verification_cert: + ssl_context.load_verify_locations(cadata=self._server_verification_cert) + else: + ssl_context.load_default_certs() + ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context.check_hostname = True + + if self._x509_cert is not None: + logger.debug("configuring SSL context with client-side certificate and key") + ssl_context.load_cert_chain( + self._x509_cert.certificate_file, + self._x509_cert.key_file, + self._x509_cert.pass_phrase, + ) + + return ssl_context + + @pipeline_thread.invoke_on_http_thread_nowait + def request(self, method, path, callback, body="", headers={}, query_params=""): + """ + This method creates a connection to a remote host, sends a request to that host, and then waits for and reads the response from that request. + + :param str method: The request method (e.g. "POST") + :param str path: The path for the URL + :param Function callback: The function that gets called when this operation is complete or has failed. The callback function must accept an error and a response dictionary, where the response dictionary contains a status code, a reason, and a response string. + :param str body: The body of the HTTP request to be sent following the headers. + :param dict headers: A dictionary that provides extra HTTP headers to be sent with the request. + :param str query_params: The optional query parameters to be appended at the end of the URL. + """ + # Sends a complete request to the server + logger.info("sending https request.") + try: + logger.debug("creating an https connection") + connection = http_client.HTTPSConnection(self._hostname, context=self._ssl_context) + logger.debug("connecting to host tcp socket") + connection.connect() + logger.debug("connection succeeded") + url = "https://{hostname}/{path}{query_params}".format( + hostname=self._hostname, + path=path, + query_params="?" + query_params if query_params else "", + ) + logger.debug("Sending Request to HTTP URL: {}".format(url)) + logger.debug("HTTP Headers: {}".format(headers)) + logger.debug("HTTP Body: {}".format(body)) + connection.request(method, url, body=body, headers=headers) + response = connection.getresponse() + status_code = response.status + reason = response.reason + response_string = response.read() + + logger.debug("response received") + logger.debug("closing connection to https host") + connection.close() + logger.debug("connection closed") + logger.info("https request sent, and response received.") + response_obj = {"status_code": status_code, "reason": reason, "resp": response_string} + callback(response=response_obj) + except Exception as e: + logger.error("Error in HTTP Transport: {}".format(e)) + callback( + error=exceptions.ProtocolClientError( + message="Unexpected HTTPS failure during connect", cause=e + ) + ) diff --git a/azure-iot-device/azure/iot/device/common/models/__init__.py b/azure-iot-device/azure/iot/device/common/models/__init__.py index e6147a3d0..ce68085ad 100644 --- a/azure-iot-device/azure/iot/device/common/models/__init__.py +++ b/azure-iot-device/azure/iot/device/common/models/__init__.py @@ -4,3 +4,4 @@ This package provides object models for use within the Azure Provisioning Device """ from .x509 import X509 +from .proxy_options import ProxyOptions diff --git a/azure-iot-device/azure/iot/device/common/models/proxy_options.py b/azure-iot-device/azure/iot/device/common/models/proxy_options.py new file mode 100644 index 000000000..fec68974f --- /dev/null +++ b/azure-iot-device/azure/iot/device/common/models/proxy_options.py @@ -0,0 +1,53 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +""" +This module represents proxy options to enable sending traffic through proxy servers. +""" + + +class ProxyOptions(object): + """ + A class containing various options to send traffic through proxy servers by enabling + proxying of MQTT connection. + """ + + def __init__( + self, proxy_type, proxy_addr, proxy_port, proxy_username=None, proxy_password=None + ): + """ + Initializer for proxy options. + :param proxy_type: The type of the proxy server. This can be one of three possible choices:socks.HTTP, socks.SOCKS4, or socks.SOCKS5 + :param proxy_addr: IP address or DNS name of proxy server + :param proxy_port: The port of the proxy server. Defaults to 1080 for socks and 8080 for http. + :param proxy_username: (optional) username for SOCKS5 proxy, or userid for SOCKS4 proxy.This parameter is ignored if an HTTP server is being used. + If it is not provided, authentication will not be used (servers may accept unauthenticated requests). + :param proxy_password: (optional) This parameter is valid only for SOCKS5 servers and specifies the respective password for the username provided. + """ + self._proxy_type = proxy_type + self._proxy_addr = proxy_addr + self._proxy_port = proxy_port + self._proxy_username = proxy_username + self._proxy_password = proxy_password + + @property + def proxy_type(self): + return self._proxy_type + + @property + def proxy_address(self): + return self._proxy_addr + + @property + def proxy_port(self): + return self._proxy_port + + @property + def proxy_username(self): + return self._proxy_username + + @property + def proxy_password(self): + return self._proxy_password diff --git a/azure-iot-device/azure/iot/device/common/mqtt_transport.py b/azure-iot-device/azure/iot/device/common/mqtt_transport.py index 042644854..19e1300f8 100644 --- a/azure-iot-device/azure/iot/device/common/mqtt_transport.py +++ b/azure-iot-device/azure/iot/device/common/mqtt_transport.py @@ -7,61 +7,74 @@ import paho.mqtt.client as mqtt import logging import ssl +import sys import threading import traceback -from . import errors +import weakref +import socket +from . import transport_exceptions as exceptions +import socks logger = logging.getLogger(__name__) -# mapping of Paho conack rc codes to Error object classes -paho_conack_rc_to_error = { - mqtt.CONNACK_REFUSED_PROTOCOL_VERSION: errors.ProtocolClientError, - mqtt.CONNACK_REFUSED_IDENTIFIER_REJECTED: errors.ProtocolClientError, - mqtt.CONNACK_REFUSED_SERVER_UNAVAILABLE: errors.ConnectionFailedError, - mqtt.CONNACK_REFUSED_BAD_USERNAME_PASSWORD: errors.UnauthorizedError, - mqtt.CONNACK_REFUSED_NOT_AUTHORIZED: errors.UnauthorizedError, +# Mapping of Paho CONNACK rc codes to Error object classes +# Used for connection callbacks +paho_connack_rc_to_error = { + mqtt.CONNACK_REFUSED_PROTOCOL_VERSION: exceptions.ProtocolClientError, + mqtt.CONNACK_REFUSED_IDENTIFIER_REJECTED: exceptions.ProtocolClientError, + mqtt.CONNACK_REFUSED_SERVER_UNAVAILABLE: exceptions.ConnectionFailedError, + mqtt.CONNACK_REFUSED_BAD_USERNAME_PASSWORD: exceptions.UnauthorizedError, + mqtt.CONNACK_REFUSED_NOT_AUTHORIZED: exceptions.UnauthorizedError, } -# mapping of Paho rc codes to Error object classes +# Mapping of Paho rc codes to Error object classes +# Used for responses to Paho APIs and non-connection callbacks paho_rc_to_error = { - mqtt.MQTT_ERR_NOMEM: errors.ProtocolClientError, - mqtt.MQTT_ERR_PROTOCOL: errors.ProtocolClientError, - mqtt.MQTT_ERR_INVAL: errors.ArgumentError, - mqtt.MQTT_ERR_NO_CONN: errors.ConnectionDroppedError, - mqtt.MQTT_ERR_CONN_REFUSED: errors.ConnectionFailedError, - mqtt.MQTT_ERR_NOT_FOUND: errors.ConnectionFailedError, - mqtt.MQTT_ERR_CONN_LOST: errors.ConnectionDroppedError, - mqtt.MQTT_ERR_TLS: errors.UnauthorizedError, - mqtt.MQTT_ERR_PAYLOAD_SIZE: errors.ProtocolClientError, - mqtt.MQTT_ERR_NOT_SUPPORTED: errors.ProtocolClientError, - mqtt.MQTT_ERR_AUTH: errors.UnauthorizedError, - mqtt.MQTT_ERR_ACL_DENIED: errors.UnauthorizedError, - mqtt.MQTT_ERR_UNKNOWN: errors.ProtocolClientError, - mqtt.MQTT_ERR_ERRNO: errors.ProtocolClientError, - mqtt.MQTT_ERR_QUEUE_SIZE: errors.ProtocolClientError, + mqtt.MQTT_ERR_NOMEM: exceptions.ProtocolClientError, + mqtt.MQTT_ERR_PROTOCOL: exceptions.ProtocolClientError, + mqtt.MQTT_ERR_INVAL: exceptions.ProtocolClientError, + mqtt.MQTT_ERR_NO_CONN: exceptions.ConnectionDroppedError, + mqtt.MQTT_ERR_CONN_REFUSED: exceptions.ConnectionFailedError, + mqtt.MQTT_ERR_NOT_FOUND: exceptions.ConnectionFailedError, + mqtt.MQTT_ERR_CONN_LOST: exceptions.ConnectionDroppedError, + mqtt.MQTT_ERR_TLS: exceptions.UnauthorizedError, + mqtt.MQTT_ERR_PAYLOAD_SIZE: exceptions.ProtocolClientError, + mqtt.MQTT_ERR_NOT_SUPPORTED: exceptions.ProtocolClientError, + mqtt.MQTT_ERR_AUTH: exceptions.UnauthorizedError, + mqtt.MQTT_ERR_ACL_DENIED: exceptions.UnauthorizedError, + mqtt.MQTT_ERR_UNKNOWN: exceptions.ProtocolClientError, + mqtt.MQTT_ERR_ERRNO: exceptions.ProtocolClientError, + mqtt.MQTT_ERR_QUEUE_SIZE: exceptions.ProtocolClientError, } +# Default keepalive. Paho sends a PINGREQ using this interval +# to make sure the connection is still open. +DEFAULT_KEEPALIVE = 60 -def _create_error_from_conack_rc_code(rc): + +def _create_error_from_connack_rc_code(rc): """ - Given a paho CONACK rc code, return an Exception that can be raised + Given a paho CONNACK rc code, return an Exception that can be raised """ message = mqtt.connack_string(rc) - if rc in paho_conack_rc_to_error: - return paho_conack_rc_to_error[rc](message) + if rc in paho_connack_rc_to_error: + return paho_connack_rc_to_error[rc](message) else: - return errors.ProtocolClientError("Unknown CONACK rc={}".format(rc)) + return exceptions.ProtocolClientError("Unknown CONNACK rc={}".format(rc)) def _create_error_from_rc_code(rc): """ Given a paho rc code, return an Exception that can be raised """ - message = mqtt.error_string(rc) - if rc in paho_rc_to_error: + if rc == 1: + # Paho returns rc=1 to mean "something went wrong. stop". We manually translate this to a ConnectionDroppedError. + return exceptions.ConnectionDroppedError("Paho returned rc==1") + elif rc in paho_rc_to_error: + message = mqtt.error_string(rc) return paho_rc_to_error[rc](message) else: - return errors.ProtocolClientError("Unknown CONACK rc={}".format(rc)) + return exceptions.ProtocolClientError("Unknown CONNACK rc=={}".format(rc)) class MQTTTransport(object): @@ -78,21 +91,37 @@ class MQTTTransport(object): :type on_mqtt_connection_failure_handler: Function """ - def __init__(self, client_id, hostname, username, ca_cert=None, x509_cert=None): + def __init__( + self, + client_id, + hostname, + username, + server_verification_cert=None, + x509_cert=None, + websockets=False, + cipher=None, + proxy_options=None, + ): """ Constructor to instantiate an MQTT protocol wrapper. :param str client_id: The id of the client connecting to the broker. :param str hostname: Hostname or IP address of the remote broker. :param str username: Username for login to the remote broker. - :param str ca_cert: Certificate which can be used to validate a server-side TLS connection (optional). + :param str server_verification_cert: Certificate which can be used to validate a server-side TLS connection (optional). :param x509_cert: Certificate which can be used to authenticate connection to a server in lieu of a password (optional). + :param bool websockets: Indicates whether or not to enable a websockets connection in the Transport. + :param str cipher: Cipher string in OpenSSL cipher list format + :param proxy_options: Options for sending traffic through proxy servers. """ self._client_id = client_id self._hostname = hostname self._username = username self._mqtt_client = None - self._ca_cert = ca_cert + self._server_verification_cert = server_verification_cert self._x509_cert = x509_cert + self._websockets = websockets + self._cipher = cipher + self._proxy_options = proxy_options self.on_mqtt_connected_handler = None self.on_mqtt_disconnected_handler = None @@ -109,25 +138,53 @@ class MQTTTransport(object): """ logger.info("creating mqtt client") - # Instantiate client - mqtt_client = mqtt.Client( - client_id=self._client_id, clean_session=False, protocol=mqtt.MQTTv311 - ) + # Instaniate the client + if self._websockets: + logger.info("Creating client for connecting using MQTT over websockets") + mqtt_client = mqtt.Client( + client_id=self._client_id, + clean_session=False, + protocol=mqtt.MQTTv311, + transport="websockets", + ) + mqtt_client.ws_set_options(path="/$iothub/websocket") + else: + logger.info("Creating client for connecting using MQTT over TCP") + mqtt_client = mqtt.Client( + client_id=self._client_id, clean_session=False, protocol=mqtt.MQTTv311 + ) + + if self._proxy_options: + mqtt_client.proxy_set( + proxy_type=self._proxy_options.proxy_type, + proxy_addr=self._proxy_options.proxy_address, + proxy_port=self._proxy_options.proxy_port, + proxy_username=self._proxy_options.proxy_username, + proxy_password=self._proxy_options.proxy_password, + ) + mqtt_client.enable_logger(logging.getLogger("paho")) # Configure TLS/SSL ssl_context = self._create_ssl_context() mqtt_client.tls_set_context(context=ssl_context) - # Set event handlers + # Set event handlers. Use weak references back into this object to prevent + # leaks on Python 2.7. See callable_weak_method.py and PEP 442 for explanation. + # + # We don't use the CallableWeakMethod object here because these handlers + # are not methods. + self_weakref = weakref.ref(self) + def on_connect(client, userdata, flags, rc): + this = self_weakref() logger.info("connected with result code: {}".format(rc)) - if rc: - if self.on_mqtt_connection_failure_handler: + if rc: # i.e. if there is an error + if this.on_mqtt_connection_failure_handler: try: - self.on_mqtt_connection_failure_handler( - _create_error_from_conack_rc_code(rc) + this.on_mqtt_connection_failure_handler( + _create_error_from_connack_rc_code(rc) ) except Exception: logger.error("Unexpected error calling on_mqtt_connection_failure_handler") @@ -136,9 +193,9 @@ class MQTTTransport(object): logger.warning( "connection failed, but no on_mqtt_connection_failure_handler handler callback provided" ) - elif self.on_mqtt_connected_handler: + elif this.on_mqtt_connected_handler: try: - self.on_mqtt_connected_handler() + this.on_mqtt_connected_handler() except Exception: logger.error("Unexpected error calling on_mqtt_connected_handler") logger.error(traceback.format_exc()) @@ -146,15 +203,18 @@ class MQTTTransport(object): logger.warning("No event handler callback set for on_mqtt_connected_handler") def on_disconnect(client, userdata, rc): + this = self_weakref() logger.info("disconnected with result code: {}".format(rc)) cause = None - if rc: + if rc: # i.e. if there is an error + logger.debug("".join(traceback.format_stack())) cause = _create_error_from_rc_code(rc) + this._stop_automatic_reconnect() - if self.on_mqtt_disconnected_handler: + if this.on_mqtt_disconnected_handler: try: - self.on_mqtt_disconnected_handler(cause) + this.on_mqtt_disconnected_handler(cause) except Exception: logger.error("Unexpected error calling on_mqtt_disconnected_handler") logger.error(traceback.format_exc()) @@ -162,29 +222,33 @@ class MQTTTransport(object): logger.warning("No event handler callback set for on_mqtt_disconnected_handler") def on_subscribe(client, userdata, mid, granted_qos): + this = self_weakref() logger.info("suback received for {}".format(mid)) # subscribe failures are returned from the subscribe() call. This is just # a notification that a SUBACK was received, so there is no failure case here - self._op_manager.complete_operation(mid) + this._op_manager.complete_operation(mid) def on_unsubscribe(client, userdata, mid): + this = self_weakref() logger.info("UNSUBACK received for {}".format(mid)) # unsubscribe failures are returned from the unsubscribe() call. This is just # a notification that a SUBACK was received, so there is no failure case here - self._op_manager.complete_operation(mid) + this._op_manager.complete_operation(mid) def on_publish(client, userdata, mid): + this = self_weakref() logger.info("payload published for {}".format(mid)) # publish failures are returned from the publish() call. This is just # a notification that a PUBACK was received, so there is no failure case here - self._op_manager.complete_operation(mid) + this._op_manager.complete_operation(mid) def on_message(client, userdata, mqtt_message): + this = self_weakref() logger.info("message received on {}".format(mqtt_message.topic)) - if self.on_mqtt_message_received_handler: + if this.on_mqtt_message_received_handler: try: - self.on_mqtt_message_received_handler(mqtt_message.topic, mqtt_message.payload) + this.on_mqtt_message_received_handler(mqtt_message.topic, mqtt_message.payload) except Exception: logger.error("Unexpected error calling on_mqtt_message_received_handler") logger.error(traceback.format_exc()) @@ -203,6 +267,40 @@ class MQTTTransport(object): logger.debug("Created MQTT protocol client, assigned callbacks") return mqtt_client + def _stop_automatic_reconnect(self): + """ + After disconnecting because of an error, Paho will attempt to reconnect (some of the time -- + this isn't 100% reliable). We don't want Paho to reconnect because we want to control the + timing of the reconnect, so we force the connection closed. + + We are relying on intimite knowledge of Paho behavior here. If this becomes a problem, + it may be necessary to write our own Paho thread and stop using thread_start()/thread_stop(). + This is certainly supported by Paho, but the thread that Paho provides works well enough + (so far) and making our own would be more complex than is currently justified. + """ + + logger.info("Forcing paho disconnect to prevent it from automatically reconnecting") + + # Note: We are calling this inside our on_disconnect() handler, so we are inside the + # Paho thread at this point. This is perfectly valid. Comments in Paho's client.py + # loop_forever() function recomment calling disconnect() from a callback to exit the + # Paho thread/loop. + + self._mqtt_client.disconnect() + + # Calling disconnect() isn't enough. We also need to call loop_stop to make sure + # Paho is as clean as possible. Our call to disconnect() above is enough to stop the + # loop and exit the tread, but the call to loop_stop() is necessary to complete the cleanup. + + self._mqtt_client.loop_stop() + + # Finally, because of a bug in Paho, we need to null out the _thread pointer. This + # is necessary because the code that sets _thread to None only gets called if you + # call loop_stop from an external thread (and we're still inside the Paho thread here). + + self._mqtt_client._thread = None + logger.debug("Done forcing paho disconnect") + def _create_ssl_context(self): """ This method creates the SSLContext object used by Paho to authenticate the connection. @@ -210,12 +308,17 @@ class MQTTTransport(object): logger.debug("creating a SSL context") ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLSv1_2) - if self._ca_cert: - ssl_context.load_verify_locations(cadata=self._ca_cert) + if self._server_verification_cert: + ssl_context.load_verify_locations(cadata=self._server_verification_cert) else: ssl_context.load_default_certs() - ssl_context.verify_mode = ssl.CERT_REQUIRED - ssl_context.check_hostname = True + + if self._cipher: + try: + ssl_context.set_ciphers(self._cipher) + except ssl.SSLError as e: + # TODO: custom error with more detail? + raise e if self._x509_cert is not None: logger.debug("configuring SSL context with client-side certificate and key") @@ -225,6 +328,9 @@ class MQTTTransport(object): self._x509_cert.pass_phrase, ) + ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context.check_hostname = True + return ssl_context def connect(self, password=None): @@ -235,45 +341,118 @@ class MQTTTransport(object): The password is not required if the transport was instantiated with an x509 certificate. + If MQTT connection has been proxied, connection will take a bit longer to allow negotiation + with the proxy server. Any errors in the proxy connection process will trigger exceptions + :param str password: The password for connecting with the MQTT broker (Optional). + + :raises: ConnectionFailedError if connection could not be established. + :raises: ConnectionDroppedError if connection is dropped during execution. + :raises: UnauthorizedError if there is an error authenticating. + :raises: ProtocolClientError if there is some other client error. """ logger.info("connecting to mqtt broker") self._mqtt_client.username_pw_set(username=self._username, password=password) - rc = self._mqtt_client.connect(host=self._hostname, port=8883) + try: + if self._websockets: + logger.info("Connect using port 443 (websockets)") + rc = self._mqtt_client.connect( + host=self._hostname, port=443, keepalive=DEFAULT_KEEPALIVE + ) + else: + logger.info("Connect using port 8883 (TCP)") + rc = self._mqtt_client.connect( + host=self._hostname, port=8883, keepalive=DEFAULT_KEEPALIVE + ) + except socket.error as e: + # Only this type will raise a special error + # To stop it from retrying. + if ( + isinstance(e, ssl.SSLError) + and e.strerror is not None + and "CERTIFICATE_VERIFY_FAILED" in e.strerror + ): + raise exceptions.TlsExchangeAuthError(cause=e) + elif isinstance(e, socks.ProxyError): + if isinstance(e, socks.SOCKS5AuthError): + # TODO This is the only I felt like specializing + raise exceptions.UnauthorizedError(cause=e) + else: + raise exceptions.ProtocolProxyError(cause=e) + else: + # If the socket can't open (e.g. using iptables REJECT), we get a + # socket.error. Convert this into ConnectionFailedError so we can retry + raise exceptions.ConnectionFailedError(cause=e) + + except socks.ProxyError as pe: + if isinstance(pe, socks.SOCKS5AuthError): + raise exceptions.UnauthorizedError(cause=pe) + else: + raise exceptions.ProtocolProxyError(cause=pe) + except Exception as e: + raise exceptions.ProtocolClientError( + message="Unexpected Paho failure during connect", cause=e + ) logger.debug("_mqtt_client.connect returned rc={}".format(rc)) if rc: raise _create_error_from_rc_code(rc) self._mqtt_client.loop_start() - def reconnect(self, password=None): + def reauthorize_connection(self, password=None): """ - Reconnect to the MQTT broker, using username set at instantiation. + Reauthorize with the MQTT broker, using username set at instantiation. Connect should have previously been called in order to use this function. The password is not required if the transport was instantiated with an x509 certificate. - :param str password: The password for reconnecting with the MQTT broker (Optional). + :param str password: The password for reauthorizing with the MQTT broker (Optional). + + :raises: ConnectionFailedError if connection could not be established. + :raises: ConnectionDroppedError if connection is dropped during execution. + :raises: UnauthorizedError if there is an error authenticating. + :raises: ProtocolClientError if there is some other client error. """ - logger.info("reconnecting MQTT client") + logger.info("reauthorizing MQTT client") self._mqtt_client.username_pw_set(username=self._username, password=password) - rc = self._mqtt_client.reconnect() + try: + rc = self._mqtt_client.reconnect() + except Exception as e: + raise exceptions.ProtocolClientError( + message="Unexpected Paho failure during reconnect", cause=e + ) logger.debug("_mqtt_client.reconnect returned rc={}".format(rc)) if rc: + # This could result in ConnectionFailedError, ConnectionDroppedError, UnauthorizedError + # or ProtocolClientError raise _create_error_from_rc_code(rc) def disconnect(self): """ Disconnect from the MQTT broker. + + :raises: ProtocolClientError if there is some client error. """ logger.info("disconnecting MQTT client") - rc = self._mqtt_client.disconnect() + try: + rc = self._mqtt_client.disconnect() + except Exception as e: + raise exceptions.ProtocolClientError( + message="Unexpected Paho failure during disconnect", cause=e + ) logger.debug("_mqtt_client.disconnect returned rc={}".format(rc)) self._mqtt_client.loop_stop() if rc: - raise _create_error_from_rc_code(rc) + # This could result in ConnectionDroppedError or ProtocolClientError + err = _create_error_from_rc_code(rc) + # If we get a ConnectionDroppedError, swallow it, because we have successfully disconnected! + if type(err) is exceptions.ConnectionDroppedError: + logger.warning("Dropped connection while disconnecting - swallowing error") + pass + else: + raise err def subscribe(self, topic, qos=1, callback=None): """ @@ -283,14 +462,25 @@ class MQTTTransport(object): :param int qos: the desired quality of service level for the subscription. Defaults to 1. :param callback: A callback to be triggered upon completion (Optional). - :return: message ID for the subscribe request - :raises: ValueError if qos is not 0, 1 or 2 - :raises: ValueError if topic is None or has zero string length + :return: message ID for the subscribe request. + + :raises: ValueError if qos is not 0, 1 or 2. + :raises: ValueError if topic is None or has zero string length. + :raises: ConnectionDroppedError if connection is dropped during execution. + :raises: ProtocolClientError if there is some other client error. """ logger.info("subscribing to {} with qos {}".format(topic, qos)) - (rc, mid) = self._mqtt_client.subscribe(topic, qos=qos) + try: + (rc, mid) = self._mqtt_client.subscribe(topic, qos=qos) + except ValueError: + raise + except Exception as e: + raise exceptions.ProtocolClientError( + message="Unexpected Paho failure during subscribe", cause=e + ) logger.debug("_mqtt_client.subscribe returned rc={}".format(rc)) if rc: + # This could result in ConnectionDroppedError or ProtocolClientError raise _create_error_from_rc_code(rc) self._op_manager.establish_operation(mid, callback) @@ -301,12 +491,22 @@ class MQTTTransport(object): :param str topic: a single string which is the subscription topic to unsubscribe from. :param callback: A callback to be triggered upon completion (Optional). - :raises: ValueError if topic is None or has zero string length + :raises: ValueError if topic is None or has zero string length. + :raises: ConnectionDroppedError if connection is dropped during execution. + :raises: ProtocolClientError if there is some other client error. """ logger.info("unsubscribing from {}".format(topic)) - (rc, mid) = self._mqtt_client.unsubscribe(topic) + try: + (rc, mid) = self._mqtt_client.unsubscribe(topic) + except ValueError: + raise + except Exception as e: + raise exceptions.ProtocolClientError( + message="Unexpected Paho failure during unsubscribe", cause=e + ) logger.debug("_mqtt_client.unsubscribe returned rc={}".format(rc)) if rc: + # This could result in ConnectionDroppedError or ProtocolClientError raise _create_error_from_rc_code(rc) self._op_manager.establish_operation(mid, callback) @@ -315,7 +515,8 @@ class MQTTTransport(object): Send a message via the MQTT broker. :param str topic: topic: The topic that the message should be published on. - :param str payload: The actual message to send. + :param payload: The actual message to send. + :type payload: str, bytes, int, float or None :param int qos: the desired quality of service level for the subscription. Defaults to 1. :param callback: A callback to be triggered upon completion (Optional). @@ -323,11 +524,24 @@ class MQTTTransport(object): :raises: ValueError if topic is None or has zero string length :raises: ValueError if topic contains a wildcard ("+") :raises: ValueError if the length of the payload is greater than 268435455 bytes + :raises: TypeError if payload is not a valid type + :raises: ConnectionDroppedError if connection is dropped during execution. + :raises: ProtocolClientError if there is some other client error. """ logger.info("publishing on {}".format(topic)) - (rc, mid) = self._mqtt_client.publish(topic=topic, payload=payload, qos=qos) + try: + (rc, mid) = self._mqtt_client.publish(topic=topic, payload=payload, qos=qos) + except ValueError: + raise + except TypeError: + raise + except Exception as e: + raise exceptions.ProtocolClientError( + message="Unexpected Paho failure during publish", cause=e + ) logger.debug("_mqtt_client.publish returned rc={}".format(rc)) if rc: + # This could result in ConnectionDroppedError or ProtocolClientError raise _create_error_from_rc_code(rc) self._op_manager.establish_operation(mid, callback) @@ -385,7 +599,7 @@ class OperationManager(object): logger.error("Unexpected error calling callback for MID: {}".format(mid)) logger.error(traceback.format_exc()) else: - logger.warning("No callback for MID: {}".format(mid)) + logger.exception("No callback for MID: {}".format(mid)) def complete_operation(self, mid): """Complete an operation identified by MID and trigger the associated completion callback. diff --git a/azure-iot-device/azure/iot/device/common/pipeline/__init__.py b/azure-iot-device/azure/iot/device/common/pipeline/__init__.py index 47f9bfbf3..d28b96427 100644 --- a/azure-iot-device/azure/iot/device/common/pipeline/__init__.py +++ b/azure-iot-device/azure/iot/device/common/pipeline/__init__.py @@ -7,3 +7,4 @@ INTERNAL USAGE ONLY from .pipeline_events_base import PipelineEvent from .pipeline_ops_base import PipelineOperation from .pipeline_stages_base import PipelineStage +from .pipeline_exceptions import OperationCancelled diff --git a/azure-iot-device/azure/iot/device/common/pipeline/config.py b/azure-iot-device/azure/iot/device/common/pipeline/config.py new file mode 100644 index 000000000..7d116cded --- /dev/null +++ b/azure-iot-device/azure/iot/device/common/pipeline/config.py @@ -0,0 +1,47 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import logging +import six +import abc + +logger = logging.getLogger(__name__) + + +@six.add_metaclass(abc.ABCMeta) +class BasePipelineConfig(object): + """A base class for storing all configurations/options shared across the Azure IoT Python Device Client Library. + More specific configurations such as those that only apply to the IoT Hub Client will be found in the respective + config files. + """ + + def __init__(self, websockets=False, cipher="", proxy_options=None): + """Initializer for BasePipelineConfig + + :param bool websockets: Enabling/disabling websockets in MQTT. This feature is relevant + if a firewall blocks port 8883 from use. + :param cipher: Optional cipher suite(s) for TLS/SSL, as a string in + "OpenSSL cipher list format" or as a list of cipher suite strings. + :type cipher: str or list(str) + """ + self.websockets = websockets + self.cipher = self._sanitize_cipher(cipher) + self.proxy_options = proxy_options + + @staticmethod + def _sanitize_cipher(cipher): + """Sanitize the cipher input and convert to a string in OpenSSL list format + """ + if isinstance(cipher, list): + cipher = ":".join(cipher) + + if isinstance(cipher, str): + cipher = cipher.upper() + cipher = cipher.replace("_", "-") + else: + raise TypeError("Invalid type for 'cipher'") + + return cipher diff --git a/azure-iot-device/azure/iot/device/common/pipeline/operation_flow.py b/azure-iot-device/azure/iot/device/common/pipeline/operation_flow.py deleted file mode 100644 index b5b88ee2b..000000000 --- a/azure-iot-device/azure/iot/device/common/pipeline/operation_flow.py +++ /dev/null @@ -1,137 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import logging -import sys -from . import pipeline_thread -from azure.iot.device.common import unhandled_exceptions - -from six.moves import queue - -logger = logging.getLogger(__name__) - - -@pipeline_thread.runs_on_pipeline_thread -def delegate_to_different_op(stage, original_op, new_op): - """ - Continue an operation using a new operation. This means that the new operation - will be passed down the pipeline (starting at the next stage). When that new - operation completes, the original operation will also complete. In this way, - a stage can accept one type of operation and, effectively, change that operation - into a different type of operation before passing it to the next stage. - - This is useful when a generic operation (such as "enable feature") needs to be - converted into a more specific operation (such as "subscribe to mqtt topic"). - In that case, a stage's _execute_op function would call this function passing in - the original "enable feature" op and the new "subscribe to mqtt topic" - op. This function will pass the "subscribe" down. When the "subscribe" op - is completed, this function will cause the original op to complete. - - This function is only really useful if there is no data returned in the - new_op that that needs to be copied back into the original_op before - completing it. If data needs to be copied this way, some other method needs - to be used. (or a "copy data back" function needs to be added to this function - as an optional parameter.) - - :param PipelineStage stage: stage to delegate the operation to - :param PipelineOperation original_op: Operation that is being continued using a - different op. This is most likely the operation that is currently being handled - by the stage. This operation is not actually continued, in that it is not - actually passed down the pipeline. Instead, the original_op operation is - effectively paused while we wait for the new_op operation to complete. When - the new_op operation completes, the original_op operation will also be completed. - :param PipelineOperation new_op: Operation that is being passed down the pipeline - to effectively continue the work represented by original_op. This is most likely - a different type of operation that is able to accomplish the intention of the - original_op in a way that is more specific than the original_op. - """ - - logger.debug("{}({}): continuing with {} op".format(stage.name, original_op.name, new_op.name)) - - @pipeline_thread.runs_on_pipeline_thread - def new_op_complete(op): - logger.debug( - "{}({}): completing with result from {}".format( - stage.name, original_op.name, new_op.name - ) - ) - original_op.error = new_op.error - complete_op(stage, original_op) - - new_op.callback = new_op_complete - pass_op_to_next_stage(stage, new_op) - - -@pipeline_thread.runs_on_pipeline_thread -def pass_op_to_next_stage(stage, op): - """ - Helper function to continue a given operation by passing it to the next stage - in the pipeline. If there is no next stage in the pipeline, this function - will fail the operation and call complete_op to return the failure back up the - pipeline. If the operation is already in an error state, this function will - complete the operation in order to return that error to the caller. - - :param PipelineStage stage: stage that the operation is being passed from - :param PipelineOperation op: Operation which is being passed on - """ - if op.error: - logger.error("{}({}): op has error. completing.".format(stage.name, op.name)) - complete_op(stage, op) - elif not stage.next: - logger.error("{}({}): no next stage. completing with error".format(stage.name, op.name)) - op.error = NotImplementedError( - "{} not handled after {} stage with no next stage".format(op.name, stage.name) - ) - complete_op(stage, op) - else: - logger.debug("{}({}): passing to next stage.".format(stage.name, op.name)) - stage.next.run_op(op) - - -@pipeline_thread.runs_on_pipeline_thread -def complete_op(stage, op): - """ - Helper function to complete an operation by calling its callback function thus - returning the result of the operation back up the pipeline. This is perferred to - calling the operation's callback directly as it provides several layers of protection - (such as a try/except wrapper) which are strongly advised. - """ - if op.error: - logger.error("{}({}): completing with error {}".format(stage.name, op.name, op.error)) - else: - logger.debug("{}({}): completing without error".format(stage.name, op.name)) - - try: - op.callback(op) - except Exception as e: - _, e, _ = sys.exc_info() - logger.error( - msg="Unhandled error calling back inside {}.complete_op() after {} complete".format( - stage.name, op.name - ), - exc_info=e, - ) - unhandled_exceptions.exception_caught_in_background_thread(e) - - -@pipeline_thread.runs_on_pipeline_thread -def pass_event_to_previous_stage(stage, event): - """ - Helper function to pass an event to the previous stage of the pipeline. This is the default - behavior of events while traveling through the pipeline. They start somewhere (maybe the - bottom) and move up the pipeline until they're handled or until they error out. - """ - if stage.previous: - logger.debug( - "{}({}): pushing event up to {}".format(stage.name, event.name, stage.previous.name) - ) - stage.previous.handle_pipeline_event(event) - else: - logger.error("{}({}): Error: unhandled event".format(stage.name, event.name)) - error = NotImplementedError( - "{} unhandled at {} stage with no previous stage".format(event.name, stage.name) - ) - unhandled_exceptions.exception_caught_in_background_thread(error) diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_events_base.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_events_base.py index d0fbb465b..68772abac 100644 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_events_base.py +++ b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_events_base.py @@ -33,34 +33,49 @@ class PipelineEvent(object): self.name = self.__class__.__name__ -class IotResponseEvent(PipelineEvent): +class ResponseEvent(PipelineEvent): """ - A PipelineEvent object which is the second part of an SendIotRequestAndWaitForResponseOperation operation - (the response). The SendIotRequestAndWaitForResponseOperation represents the common operation of sending + A PipelineEvent object which is the second part of an RequestAndResponseOperation operation + (the response). The RequestAndResponseOperation represents the common operation of sending a request to iothub with a request_id ($rid) value and waiting for a response with the same $rid value. This convention is used by both Twin and Provisioning features. The response represented by this event has not yet been matched to the corresponding - SendIotRequestOperation operation. That matching is done by the CoordinateRequestAndResponseStage - stage which takes the contents of this event and puts it into the SendIotRequestAndWaitForResponseOperation + RequestOperation operation. That matching is done by the CoordinateRequestAndResponseStage + stage which takes the contents of this event and puts it into the RequestAndResponseOperation operation with the matching $rid value. - :ivar status_code: The status code returned by the response. Any value under 300 is - considered success. - :type status_code: int - :ivar request_id: The request ID which will eventually be used to match a SendIotRequestOperation + :ivar request_id: The request ID which will eventually be used to match a RequestOperation operation to this event. - :type request: str + :type request_id: str + :ivar status_code: The status code returned by the response. Any value under 300 is + considered success. + :type status_code: int :ivar response_body: The body of the response. - :type request_body: str - - :ivar status_code: - :type status: int - :ivar respons_body: + :type response_body: str + :ivar retry_after: A retry interval value that was extracted from the topic. + :type retry_after: int """ - def __init__(self, request_id, status_code, response_body): - super(IotResponseEvent, self).__init__() + def __init__(self, request_id, status_code, response_body, retry_after=None): + super(ResponseEvent, self).__init__() self.request_id = request_id self.status_code = status_code self.response_body = response_body + self.retry_after = retry_after + + +class ConnectedEvent(PipelineEvent): + """ + A PipelineEvent object indicating a connection has been established. + """ + + pass + + +class DisconnectedEvent(PipelineEvent): + """ + A PipelineEvent object indicating a connection has been dropped. + """ + + pass diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_exceptions.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_exceptions.py new file mode 100644 index 000000000..7ff804178 --- /dev/null +++ b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_exceptions.py @@ -0,0 +1,40 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +"""This module defines exceptions that may be raised from a pipeline""" + +from azure.iot.device.common.chainable_exception import ChainableException + + +class PipelineException(ChainableException): + """Generic pipeline exception""" + + pass + + +class OperationCancelled(PipelineException): + """Operation was cancelled""" + + pass + + +class OperationError(PipelineException): + """Error while executing an Operation""" + + pass + + +class PipelineTimeoutError(PipelineException): + """ + Pipeline operation timed out + """ + + pass + + +class PipelineError(PipelineException): + """Error caused by incorrect pipeline configuration""" + + pass diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_base.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_base.py index a4ef79e36..0abaea62d 100644 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_base.py +++ b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_base.py @@ -3,6 +3,14 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +import sys +import logging +import traceback +from . import pipeline_exceptions +from . import pipeline_thread +from azure.iot.device.common import handle_exceptions + +logger = logging.getLogger(__name__) class PipelineOperation(object): @@ -19,33 +27,194 @@ class PipelineOperation(object): :ivar name: The name of the operation. This is used primarily for logging :type name: str :ivar callback: The callback that is called when the operation is completed, either - successfully or with a failure. + successfully or with a failure. :type callback: Function :ivar needs_connection: This is an attribute that indicates whether a particular operation - requires a connection to operate. This is currently used by the EnsureConnectionStage - stage, but this functionality will be revamped shortly. + requires a connection to operate. This is currently used by the AutoConnectStage + stage, but this functionality will be revamped shortly. :type needs_connection: Boolean :ivar error: The presence of a value in the error attribute indicates that the operation failed, - absence of this value indicates that the operation either succeeded or hasn't been handled yet. + absence of this value indicates that the operation either succeeded or hasn't been handled yet. :type error: Error """ - def __init__(self, callback=None): + def __init__(self, callback): """ Initializer for PipelineOperation objects. :param Function callback: The function that gets called when this operation is complete or has - failed. The callback function must accept A PipelineOperation object which indicates - the specific operation which has completed or failed. + failed. The callback function must accept A PipelineOperation object which indicates + the specific operation which has completed or failed. """ if self.__class__ == PipelineOperation: raise TypeError( "Cannot instantiate PipelineOperation object. You need to use a derived class" ) self.name = self.__class__.__name__ - self.callback = callback + self.callback_stack = [] self.needs_connection = False - self.error = None + self.completed = False # Operation has been fully completed + self.completing = False # Operation is in the process of completing + self.error = None # Error associated with Operation completion + + self.add_callback(callback) + + def add_callback(self, callback): + """Adds a callback to the Operation that will be triggered upon Operation completion. + + When an Operation is completed, all callbacks will be resolved in LIFO order. + + Callbacks cannot be added to an already completed operation, or an operation that is + currently undergoing a completion process. + + :param callback: The callback to add to the operation. + + :raises: OperationError if the operation is already completed, or is in the process of + completing. + """ + if self.completed: + raise pipeline_exceptions.OperationError( + "{}: Attempting to add a callback to an already-completed operation!".format( + self.name + ) + ) + if self.completing: + raise pipeline_exceptions.OperationError( + "{}: Attempting to add a callback to a operation with completion in progress!".format( + self.name + ) + ) + else: + self.callback_stack.append(callback) + + @pipeline_thread.runs_on_pipeline_thread + def complete(self, error=None): + """ Complete the operation, and trigger all callbacks in LIFO order. + + The operation is completed successfully be default, or completed unsucessfully if an error + is provided. + + An operation that is already fully completed, or in the process of completion cannot be + completed again. + + This process can be halted if a callback for the operation invokes the .halt_completion() + method on this Operation. + + :param error: Optionally provide an Exception object indicating the error that caused + the completion. Providing an error indicates that the operation was unsucessful. + """ + if error: + logger.error("{}: completing with error {}".format(self.name, error)) + else: + logger.debug("{}: completing without error".format(self.name)) + + if self.completed or self.completing: + logger.error("{}: has already been completed!".format(self.name)) + e = pipeline_exceptions.OperationError( + "Attempting to complete an already-completed operation: {}".format(self.name) + ) + # This could happen in a foreground or background thread, so err on the side of caution + # and send it to the background handler. + handle_exceptions.handle_background_exception(e) + else: + # Operation is now in the process of completing + self.completing = True + self.error = error + + while self.callback_stack: + if not self.completing: + logger.debug("{}: Completion halted!".format(self.name)) + break + if self.completed: + # This block should never be reached - this is an invalid state. + # If this block is reached, there is a bug in the code. + logger.error( + "{}: Invalid State! Operation completed while resolving completion".format( + self.name + ) + ) + e = pipeline_exceptions.OperationError( + "Operation reached fully completed state while still resolving completion: {}".format( + self.name + ) + ) + handle_exceptions.handle_background_exception(e) + break + + callback = self.callback_stack.pop() + try: + callback(op=self, error=error) + except Exception as e: + logger.error( + "Unhandled error while triggering callback for {}".format(self.name) + ) + logger.error(traceback.format_exc()) + # This could happen in a foreground or background thread, so err on the side of caution + # and send it to the background handler. + handle_exceptions.handle_background_exception(e) + + if self.completing: + # Operation is now completed, no longer in the process of completing + self.completing = False + self.completed = True + + @pipeline_thread.runs_on_pipeline_thread + def halt_completion(self): + """Halt the completion of an operation that is currently undergoing a completion process + as a result of a call to .complete(). + + Completion cannot be halted if there is no currently ongoing completion process. The only + way to successfully invoke this method is from within a callback on the Operation in + question. + + This method will leave any yet-untriggered callbacks on the Operation to be triggered upon + a later completion. + + This method will clear any error associated with the currently ongoing completion process + from the Operation. + """ + if not self.completing: + logger.error("{}: is not currently in the process of completion!".format(self.name)) + e = pipeline_exceptions.OperationError( + "Attempting to halt completion of an operation not in the process of completion: {}".format( + self.name + ) + ) + handle_exceptions.handle_background_exception(e) + else: + logger.debug("{}: Halting completion...".format(self.name)) + self.completing = False + self.error = None + + @pipeline_thread.runs_on_pipeline_thread + def spawn_worker_op(self, worker_op_type, **kwargs): + """Create and return a new operation, which, when completed, will complete the operation + it was spawned from. + + :param worker_op_type: The type (class) of the new worker operation. + :param **kwargs: The arguments to instantiate the new worker operation with. Note that a + callback is not required, but if provided, will be triggered prior to completing the + operation that spawned the worker operation. + + :returns: A new worker operation of the type specified in the worker_op_type parameter. + """ + logger.debug("{}: creating worker op of type {}".format(self.name, worker_op_type.__name__)) + + @pipeline_thread.runs_on_pipeline_thread + def on_worker_op_complete(op, error): + logger.debug("{}: Worker op ({}) has been completed".format(self.name, op.name)) + self.complete(error=error) + + if "callback" in kwargs: + provided_callback = kwargs["callback"] + kwargs["callback"] = on_worker_op_complete + worker_op = worker_op_type(**kwargs) + worker_op.add_callback(provided_callback) + else: + kwargs["callback"] = on_worker_op_complete + worker_op = worker_op_type(**kwargs) + + return worker_op class ConnectOperation(PipelineOperation): @@ -57,17 +226,19 @@ class ConnectOperation(PipelineOperation): Even though this is an base operation, it will most likely be handled by a more specific stage (such as an IoTHub or MQTT stage). """ - pass + def __init__(self, callback): + self.retry_timer = None + super(ConnectOperation, self).__init__(callback) -class ReconnectOperation(PipelineOperation): +class ReauthorizeConnectionOperation(PipelineOperation): """ - A PipelineOperation object which tells the pipeline to reconnect to whatever service it is connected to. + A PipelineOperation object which tells the pipeline to reauthorize the connection to whatever service it is connected to. - Clients will most-likely submit a Reconnect operation when some credential (such as a sas token) has changed and the protocol client + Clients will most-likely submit a ReauthorizeConnectionOperation when some credential (such as a sas token) has changed and the protocol client needs to re-establish the connection to refresh the credentials - This operation is in the group of base operations because reconnecting is a common operation that many clients might need to do. + This operation is in the group of base operations because reauthorizinging is a common operation that many clients might need to do. Even though this is an base operation, it will most likely be handled by a more specific stage (such as an IoTHub or MQTT stage). """ @@ -101,15 +272,15 @@ class EnableFeatureOperation(PipelineOperation): Even though this is an base operation, it will most likely be handled by a more specific stage (such as an IoTHub or MQTT stage). """ - def __init__(self, feature_name, callback=None): + def __init__(self, feature_name, callback): """ Initializer for EnableFeatureOperation objects. :param str feature_name: Name of the feature that is being enabled. The meaning of this - string is defined in the stage which handles this operation. + string is defined in the stage which handles this operation. :param Function callback: The function that gets called when this operation is complete or has - failed. The callback function must accept A PipelineOperation object which indicates - the specific operation which has completed or failed. + failed. The callback function must accept A PipelineOperation object which indicates + the specific operation which has completed or failed. """ super(EnableFeatureOperation, self).__init__(callback=callback) self.feature_name = feature_name @@ -129,15 +300,15 @@ class DisableFeatureOperation(PipelineOperation): Even though this is an base operation, it will most likely be handled by a more specific stage (such as an IoTHub or MQTT stage). """ - def __init__(self, feature_name, callback=None): + def __init__(self, feature_name, callback): """ Initializer for DisableFeatureOperation objects. :param str feature_name: Name of the feature that is being disabled. The meaning of this - string is defined in the stage which handles this operation. + string is defined in the stage which handles this operation. :param Function callback: The function that gets called when this operation is complete or has - failed. The callback function must accept A PipelineOperation object which indicates - the specific operation which has completed or failed. + failed. The callback function must accept A PipelineOperation object which indicates + the specific operation which has completed or failed. """ super(DisableFeatureOperation, self).__init__(callback=callback) self.feature_name = feature_name @@ -154,21 +325,21 @@ class UpdateSasTokenOperation(PipelineOperation): (such as IoTHub or MQTT stages). """ - def __init__(self, sas_token, callback=None): + def __init__(self, sas_token, callback): """ Initializer for UpdateSasTokenOperation objects. :param str sas_token: The token string which will be used to authenticate with whatever - service this pipeline connects with. + service this pipeline connects with. :param Function callback: The function that gets called when this operation is complete or has - failed. The callback function must accept A PipelineOperation object which indicates - the specific operation which has completed or failed. + failed. The callback function must accept A PipelineOperation object which indicates + the specific operation which has completed or failed. """ super(UpdateSasTokenOperation, self).__init__(callback=callback) self.sas_token = sas_token -class SendIotRequestAndWaitForResponseOperation(PipelineOperation): +class RequestAndResponseOperation(PipelineOperation): """ A PipelineOperation object which wraps the common operation of sending a request to iothub with a request_id ($rid) value and waiting for a response with the same $rid value. This convention is used by both Twin and Provisioning @@ -185,65 +356,80 @@ class SendIotRequestAndWaitForResponseOperation(PipelineOperation): :type status_code: int :ivar response_body: The body of the response. :type response_body: Undefined + :ivar query_params: Any query parameters that need to be sent with the request. + Example is the id of the operation as returned by the initial provisioning request. """ - def __init__(self, request_type, method, resource_location, request_body, callback=None): + def __init__( + self, request_type, method, resource_location, request_body, callback, query_params=None + ): """ - Initializer for SendIotRequestAndWaitForResponseOperation objects + Initializer for RequestAndResponseOperation objects :param str request_type: The type of request. This is a string which is used by protocol-specific stages to - generate the actual request. For example, if request_type is "twin", then the iothub_mqtt stage will convert - the request into an MQTT publish with topic that begins with $iothub/twin + generate the actual request. For example, if request_type is "twin", then the iothub_mqtt stage will convert + the request into an MQTT publish with topic that begins with $iothub/twin :param str method: The method for the request, in the REST sense of the word, such as "POST", "GET", etc. :param str resource_location: The resource that the method is acting on, in the REST sense of the word. - For twin request with method "GET", this is most likely the string "/" which retrieves the entire twin + For twin request with method "GET", this is most likely the string "/" which retrieves the entire twin :param request_body: The body of the request. This is a required field, and a single space can be used to denote - an empty body. + an empty body. :type request_body: Undefined :param Function callback: The function that gets called when this operation is complete or has - failed. The callback function must accept A PipelineOperation object which indicates - the specific operation which has completed or failed. + failed. The callback function must accept A PipelineOperation object which indicates + the specific operation which has completed or failed. """ - super(SendIotRequestAndWaitForResponseOperation, self).__init__(callback=callback) + super(RequestAndResponseOperation, self).__init__(callback=callback) self.request_type = request_type self.method = method self.resource_location = resource_location self.request_body = request_body self.status_code = None self.response_body = None + self.query_params = query_params -class SendIotRequestOperation(PipelineOperation): +class RequestOperation(PipelineOperation): """ - A PipelineOperation object which is the first part of an SendIotRequestAndWaitForResponseOperation operation (the request). The second - part of the SendIotRequestAndWaitForResponseOperation operation (the response) is returned via an IotResponseEvent event. + A PipelineOperation object which is the first part of an RequestAndResponseOperation operation (the request). The second + part of the RequestAndResponseOperation operation (the response) is returned via an ResponseEvent event. Even though this is an base operation, it will most likely be generated and also handled by more specifics stages (such as IoTHub or MQTT stages). """ def __init__( - self, request_type, method, resource_location, request_body, request_id, callback=None + self, + request_type, + method, + resource_location, + request_body, + request_id, + callback, + query_params=None, ): """ - Initializer for SendIotRequestOperation objects + Initializer for RequestOperation objects :param str request_type: The type of request. This is a string which is used by protocol-specific stages to - generate the actual request. For example, if request_type is "twin", then the iothub_mqtt stage will convert - the request into an MQTT publish with topic that begins with $iothub/twin + generate the actual request. For example, if request_type is "twin", then the iothub_mqtt stage will convert + the request into an MQTT publish with topic that begins with $iothub/twin :param str method: The method for the request, in the REST sense of the word, such as "POST", "GET", etc. :param str resource_location: The resource that the method is acting on, in the REST sense of the word. - For twin request with method "GET", this is most likely the string "/" which retrieves the entire twin + For twin request with method "GET", this is most likely the string "/" which retrieves the entire twin :param request_body: The body of the request. This is a required field, and a single space can be used to denote - an empty body. + an empty body. :type request_body: dict, str, int, float, bool, or None (JSON compatible values) :param Function callback: The function that gets called when this operation is complete or has - failed. The callback function must accept A PipelineOperation object which indicates - the specific operation which has completed or failed. + failed. The callback function must accept A PipelineOperation object which indicates + the specific operation which has completed or failed. + :type query_params: Any query parameters that need to be sent with the request. + Example is the id of the operation as returned by the initial provisioning request. """ - super(SendIotRequestOperation, self).__init__(callback=callback) + super(RequestOperation, self).__init__(callback=callback) self.method = method self.resource_location = resource_location self.request_type = request_type self.request_body = request_body self.request_id = request_id + self.query_params = query_params diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_http.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_http.py new file mode 100644 index 000000000..44d3e0176 --- /dev/null +++ b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_http.py @@ -0,0 +1,65 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from . import PipelineOperation + + +class SetHTTPConnectionArgsOperation(PipelineOperation): + """ + A PipelineOperation object which contains arguments used to connect to a server using the HTTP protocol. + + This operation is in the group of HTTP operations because its attributes are very specific to the HTTP protocol. + """ + + def __init__( + self, hostname, callback, server_verification_cert=None, client_cert=None, sas_token=None + ): + """ + Initializer for SetHTTPConnectionArgsOperation objects. + :param str hostname: The hostname of the HTTP server we will eventually connect to + :param str server_verification_cert: (Optional) The server verification certificate to use + if the HTTP server that we're going to connect to uses server-side TLS + :param X509 client_cert: (Optional) The x509 object containing a client certificate and key used to connect + to the HTTP service + :param str sas_token: The token string which will be used to authenticate with the service + :param Function callback: The function that gets called when this operation is complete or has failed. + The callback function must accept A PipelineOperation object which indicates the specific operation which + has completed or failed. + """ + super(SetHTTPConnectionArgsOperation, self).__init__(callback=callback) + self.hostname = hostname + self.server_verification_cert = server_verification_cert + self.client_cert = client_cert + self.sas_token = sas_token + + +class HTTPRequestAndResponseOperation(PipelineOperation): + """ + A PipelineOperation object which contains arguments used to connect to a server using the HTTP protocol. + + This operation is in the group of HTTP operations because its attributes are very specific to the HTTP protocol. + """ + + def __init__(self, method, path, headers, body, query_params, callback): + """ + Initializer for HTTPPublishOperation objects. + :param str method: The HTTP method used in the request + :param str path: The path to be used in the request url + :param dict headers: The headers to be used in the HTTP request + :param str body: The body to be provided with the HTTP request + :param str query_params: The query parameters to be used in the request url + :param Function callback: The function that gets called when this operation is complete or has failed. + The callback function must accept A PipelineOperation object which indicates the specific operation which + has completed or failed. + """ + super(HTTPRequestAndResponseOperation, self).__init__(callback=callback) + self.method = method + self.path = path + self.headers = headers + self.body = body + self.query_params = query_params + self.status_code = None + self.response_body = None + self.reason = None diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_mqtt.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_mqtt.py index a6106b07d..651c95dc9 100644 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_mqtt.py +++ b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_mqtt.py @@ -18,10 +18,10 @@ class SetMQTTConnectionArgsOperation(PipelineOperation): client_id, hostname, username, - ca_cert=None, + callback, + server_verification_cert=None, client_cert=None, sas_token=None, - callback=None, ): """ Initializer for SetMQTTConnectionArgsOperation objects. @@ -29,8 +29,8 @@ class SetMQTTConnectionArgsOperation(PipelineOperation): :param str client_id: The client identifier to use when connecting to the MQTT server :param str hostname: The hostname of the MQTT server we will eventually connect to :param str username: The username to use when connecting to the MQTT server - :param str ca_cert: (Optional) The CA certificate to use if the MQTT server that we're going to - connect to uses server-side TLS + :param str server_verification_cert: (Optional) The server verification certificate to use + if the MQTT server that we're going to connect to uses server-side TLS :param X509 client_cert: (Optional) The x509 object containing a client certificate and key used to connect to the MQTT service :param str sas_token: The token string which will be used to authenticate with the service @@ -42,7 +42,7 @@ class SetMQTTConnectionArgsOperation(PipelineOperation): self.client_id = client_id self.hostname = hostname self.username = username - self.ca_cert = ca_cert + self.server_verification_cert = server_verification_cert self.client_cert = client_cert self.sas_token = sas_token @@ -54,7 +54,7 @@ class MQTTPublishOperation(PipelineOperation): This operation is in the group of MQTT operations because its attributes are very specific to the MQTT protocol. """ - def __init__(self, topic, payload, callback=None): + def __init__(self, topic, payload, callback): """ Initializer for MQTTPublishOperation objects. @@ -68,6 +68,7 @@ class MQTTPublishOperation(PipelineOperation): self.topic = topic self.payload = payload self.needs_connection = True + self.retry_timer = None class MQTTSubscribeOperation(PipelineOperation): @@ -77,7 +78,7 @@ class MQTTSubscribeOperation(PipelineOperation): This operation is in the group of MQTT operations because its attributes are very specific to the MQTT protocol. """ - def __init__(self, topic, callback=None): + def __init__(self, topic, callback): """ Initializer for MQTTSubscribeOperation objects. @@ -89,6 +90,8 @@ class MQTTSubscribeOperation(PipelineOperation): super(MQTTSubscribeOperation, self).__init__(callback=callback) self.topic = topic self.needs_connection = True + self.timeout_timer = None + self.retry_timer = None class MQTTUnsubscribeOperation(PipelineOperation): @@ -98,7 +101,7 @@ class MQTTUnsubscribeOperation(PipelineOperation): This operation is in the group of MQTT operations because its attributes are very specific to the MQTT protocol. """ - def __init__(self, topic, callback=None): + def __init__(self, topic, callback): """ Initializer for MQTTUnsubscribeOperation objects. @@ -110,3 +113,5 @@ class MQTTUnsubscribeOperation(PipelineOperation): super(MQTTUnsubscribeOperation, self).__init__(callback=callback) self.topic = topic self.needs_connection = True + self.timeout_timer = None + self.retry_timer = None diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_base.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_base.py index ca054a012..1a98a5892 100644 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_base.py +++ b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_base.py @@ -7,13 +7,19 @@ import logging import abc import six +import sys +import time +import traceback import uuid +import weakref from six.moves import queue +import threading from . import pipeline_events_base -from . import pipeline_ops_base -from . import operation_flow +from . import pipeline_ops_base, pipeline_ops_mqtt from . import pipeline_thread -from azure.iot.device.common import unhandled_exceptions +from . import pipeline_exceptions +from azure.iot.device.common import handle_exceptions, transport_exceptions +from azure.iot.device.common.callable_weak_method import CallableWeakMethod logger = logging.getLogger(__name__) @@ -43,11 +49,11 @@ class PipelineStage(object): (use an auth provider) and converts it into something more generic (here is your device_id, etc, and use this SAS token when connecting). - An example of a generic-to-specific stage is IoTHubMQTTConverterStage which converts IoTHub operations + An example of a generic-to-specific stage is IoTHubMQTTTranslationStage which converts IoTHub operations (such as SendD2CMessageOperation) to MQTT operations (such as Publish). Each stage should also work in the broadest domain possible. For example a generic stage (say - "EnsureConnectionStage") that initiates a connection if any arbitrary operation needs a connection is more useful + "AutoConnectStage") that initiates a connection if any arbitrary operation needs a connection is more useful than having some MQTT-specific code that re-connects to the MQTT broker if the user calls Publish and there's no connection. @@ -81,7 +87,7 @@ class PipelineStage(object): def run_op(self, op): """ Run the given operation. This is the public function that outside callers would call to run an - operation. Derived classes should override the private _execute_op function to implement + operation. Derived classes should override the private _run_op function to implement stage-specific behavior. When run_op returns, that doesn't mean that the operation has executed to completion. Rather, it means that the pipeline has done something that will cause the operation to eventually execute to completion. That might mean that something was sent over @@ -92,29 +98,29 @@ class PipelineStage(object): :param PipelineOperation op: The operation to run. """ - logger.debug("{}({}): running".format(self.name, op.name)) try: - self._execute_op(op) + self._run_op(op) except Exception as e: # This path is ONLY for unexpected errors. Expected errors should cause a fail completion - # within ._execute_op() - logger.error(msg="Unexpected error in {}._execute_op() call".format(self), exc_info=e) - op.error = e - operation_flow.complete_op(self, op) + # within ._run_op() - @abc.abstractmethod - def _execute_op(self, op): + # Do not use exc_info parameter on logger.error. This casuses pytest to save the traceback which saves stack frames which shows up as a leak + logger.error(msg="Unexpected error in {}._run_op() call".format(self)) + logger.error(traceback.format_exc()) + op.complete(error=e) + + @pipeline_thread.runs_on_pipeline_thread + def _run_op(self, op): """ - Abstract method to run the actual operation. This function is implemented in derived classes - and performs the actual work that any operation expects. The default behavior for this function - should be to forward the event to the next stage using operation_flow.pass_op_to_next_stage for any - operations that a particular stage might not operate on. + Implementation of the stage-specific function of .run_op(). Override this method instead of + .run_op() in child classes in order to change how a stage behaves when running an operation. - See the description of the run_op method for more discussion on what it means to "run" an operation. + See the description of the .run_op() method for more discussion on what it means to "run" + an operation. :param PipelineOperation op: The operation to run. """ - pass + self.send_op_down(op) @pipeline_thread.runs_on_pipeline_thread def handle_pipeline_event(self, event): @@ -129,10 +135,10 @@ class PipelineStage(object): try: self._handle_pipeline_event(event) except Exception as e: - logger.error( - msg="Unexpected error in {}._handle_pipeline_event() call".format(self), exc_info=e - ) - unhandled_exceptions.exception_caught_in_background_thread(e) + # Do not use exc_info parameter on logger.error. This casuses pytest to save the traceback which saves stack frames which shows up as a leak + logger.error(msg="Unexpected error in {}._handle_pipeline_event() call".format(self)) + logger.error(traceback.format_exc()) + handle_exceptions.handle_background_exception(e) @pipeline_thread.runs_on_pipeline_thread def _handle_pipeline_event(self, event): @@ -143,23 +149,42 @@ class PipelineStage(object): :param PipelineEvent event: The event that is being passed back up the pipeline """ - operation_flow.pass_event_to_previous_stage(self, event) + self.send_event_up(event) @pipeline_thread.runs_on_pipeline_thread - def on_connected(self): + def send_op_down(self, op): """ - Called by lower layers when the protocol client connects + Helper function to continue a given operation by passing it to the next stage + in the pipeline. If there is no next stage in the pipeline, this function + will fail the operation and call complete_op to return the failure back up the + pipeline. + + :param PipelineOperation op: Operation which is being passed on """ - if self.previous: - self.previous.on_connected() + if not self.next: + logger.error("{}({}): no next stage. completing with error".format(self.name, op.name)) + error = pipeline_exceptions.PipelineError( + "{} not handled after {} stage with no next stage".format(op.name, self.name) + ) + op.complete(error=error) + else: + self.next.run_op(op) @pipeline_thread.runs_on_pipeline_thread - def on_disconnected(self): + def send_event_up(self, event): """ - Called by lower layers when the protocol client disconnects + Helper function to pass an event to the previous stage of the pipeline. This is the default + behavior of events while traveling through the pipeline. They start somewhere (maybe the + bottom) and move up the pipeline until they're handled or until they error out. """ if self.previous: - self.previous.on_disconnected() + self.previous.handle_pipeline_event(event) + else: + logger.error("{}({}): Error: unhandled event".format(self.name, event.name)) + error = pipeline_exceptions.PipelineError( + "{} unhandled at {} stage with no previous stage".format(event.name, self.name) + ) + handle_exceptions.handle_background_exception(error) class PipelineRootStage(PipelineStage): @@ -181,42 +206,36 @@ class PipelineRootStage(PipelineStage): :type on_disconnected_handler: Function """ - def __init__(self): + def __init__(self, pipeline_configuration): super(PipelineRootStage, self).__init__() self.on_pipeline_event_handler = None self.on_connected_handler = None self.on_disconnected_handler = None self.connected = False + self.pipeline_configuration = pipeline_configuration def run_op(self, op): - op.callback = pipeline_thread.invoke_on_callback_thread_nowait(op.callback) + # CT-TODO: make this more elegant + op.callback_stack[0] = pipeline_thread.invoke_on_callback_thread_nowait( + op.callback_stack[0] + ) pipeline_thread.invoke_on_pipeline_thread(super(PipelineRootStage, self).run_op)(op) - @pipeline_thread.runs_on_pipeline_thread - def _execute_op(self, op): - """ - run the operation. At the root, the only thing to do is to pass the operation - to the next stage. - - :param PipelineOperation op: Operation to run. - """ - operation_flow.pass_op_to_next_stage(self, op) - - def append_stage(self, new_next_stage): + def append_stage(self, new_stage): """ Add the next stage to the end of the pipeline. This is the function that callers use to build the pipeline by appending stages. This function returns the root of the pipeline so that calls to this function can be chained together. - :param PipelineStage new_next_stage: Stage to add to the end of the pipeline + :param PipelineStage new_stage: Stage to add to the end of the pipeline :returns: The root of the pipeline. """ old_tail = self while old_tail.next: old_tail = old_tail.next - old_tail.next = new_next_stage - new_next_stage.previous = old_tail - new_next_stage.pipeline_root = self + old_tail.next = new_stage + new_stage.previous = old_tail + new_stage.pipeline_root = self return self @pipeline_thread.runs_on_pipeline_thread @@ -229,42 +248,39 @@ class PipelineRootStage(PipelineStage): :param PipelineEvent event: Event to be handled, i.e. returned to the caller through the handle_pipeline_event (if provided). """ - if self.on_pipeline_event_handler: - pipeline_thread.invoke_on_callback_thread_nowait(self.on_pipeline_event_handler)(event) + if isinstance(event, pipeline_events_base.ConnectedEvent): + logger.debug( + "{}: ConnectedEvent received. Calling on_connected_handler".format(self.name) + ) + self.connected = True + if self.on_connected_handler: + pipeline_thread.invoke_on_callback_thread_nowait(self.on_connected_handler)() + + elif isinstance(event, pipeline_events_base.DisconnectedEvent): + logger.debug( + "{}: DisconnectedEvent received. Calling on_disconnected_handler".format(self.name) + ) + self.connected = False + if self.on_disconnected_handler: + pipeline_thread.invoke_on_callback_thread_nowait(self.on_disconnected_handler)() + else: - logger.warning("incoming pipeline event with no handler. dropping.") - - @pipeline_thread.runs_on_pipeline_thread - def on_connected(self): - logger.debug( - "{}: on_connected. on_connected_handler={}".format( - self.name, self.on_connected_handler - ) - ) - self.connected = True - if self.on_connected_handler: - pipeline_thread.invoke_on_callback_thread_nowait(self.on_connected_handler)() - - @pipeline_thread.runs_on_pipeline_thread - def on_disconnected(self): - logger.debug( - "{}: on_disconnected. on_disconnected_handler={}".format( - self.name, self.on_disconnected_handler - ) - ) - self.connected = False - if self.on_disconnected_handler: - pipeline_thread.invoke_on_callback_thread_nowait(self.on_disconnected_handler)() + if self.on_pipeline_event_handler: + pipeline_thread.invoke_on_callback_thread_nowait(self.on_pipeline_event_handler)( + event + ) + else: + logger.warning("incoming pipeline event with no handler. dropping.") -class EnsureConnectionStage(PipelineStage): +class AutoConnectStage(PipelineStage): """ This stage is responsible for ensuring that the protocol is connected when it needs to be connected. """ @pipeline_thread.runs_on_pipeline_thread - def _execute_op(self, op): + def _run_op(self, op): # Any operation that requires a connection can trigger a connection if # we're not connected. if op.needs_connection and not self.pipeline_root.connected: @@ -278,90 +294,95 @@ class EnsureConnectionStage(PipelineStage): # Finally, if this stage doesn't need to do anything else with this operation, # it just passes it down. else: - operation_flow.pass_op_to_next_stage(self, op) + self.send_op_down(op) @pipeline_thread.runs_on_pipeline_thread def _do_connect(self, op): """ Start connecting the transport in response to some operation """ + # Alias to avoid overload within the callback below + # CT-TODO: remove the need for this with better callback semantics + op_needs_complete = op + # function that gets called after we're connected. @pipeline_thread.runs_on_pipeline_thread - def on_connect_op_complete(op_connect): - if op_connect.error: + def on_connect_op_complete(op, error): + if error: logger.error( "{}({}): Connection failed. Completing with failure because of connection failure: {}".format( - self.name, op.name, op_connect.error + self.name, op_needs_complete.name, error ) ) - op.error = op_connect.error - operation_flow.complete_op(stage=self, op=op) + op_needs_complete.complete(error=error) else: logger.debug( - "{}({}): connection is complete. Continuing with op".format(self.name, op.name) + "{}({}): connection is complete. Continuing with op".format( + self.name, op_needs_complete.name + ) ) - operation_flow.pass_op_to_next_stage(stage=self, op=op) + self.send_op_down(op_needs_complete) # call down to the next stage to connect. logger.debug("{}({}): calling down with Connect operation".format(self.name, op.name)) - operation_flow.pass_op_to_next_stage( - self, pipeline_ops_base.ConnectOperation(callback=on_connect_op_complete) - ) + self.send_op_down(pipeline_ops_base.ConnectOperation(callback=on_connect_op_complete)) -class SerializeConnectOpsStage(PipelineStage): +class ConnectionLockStage(PipelineStage): """ - This stage is responsible for serializing connect, disconnect, and reconnect ops on + This stage is responsible for serializing connect, disconnect, and reauthorize ops on the pipeline, such that only a single one of these ops can go past this stage at a time. This way, we don't have to worry about cases like "what happens if we try to - disconnect if we're in the middle of reconnecting." This stage will wait for the - reconnect to complete before letting the disconnect past. + disconnect if we're in the middle of reauthorizing." This stage will wait for the + reauthorize to complete before letting the disconnect past. """ def __init__(self): - super(SerializeConnectOpsStage, self).__init__() + super(ConnectionLockStage, self).__init__() self.queue = queue.Queue() self.blocked = False @pipeline_thread.runs_on_pipeline_thread - def _execute_op(self, op): + def _run_op(self, op): + # If this stage is currently blocked (because we're waiting for a connection, etc, # to complete), we queue up all operations until after the connect completes. if self.blocked: logger.info( - "{}({}): pipeline is blocked waiting for a prior connect/disconnect/reconnect to complete. queueing.".format( + "{}({}): pipeline is blocked waiting for a prior connect/disconnect/reauthorize to complete. queueing.".format( self.name, op.name ) ) self.queue.put_nowait(op) elif isinstance(op, pipeline_ops_base.ConnectOperation) and self.pipeline_root.connected: - logger.info("{}({}): Transport is connected. Completing.".format(self.name, op.name)) - operation_flow.complete_op(stage=self, op=op) + logger.info( + "{}({}): Transport is already connected. Completing.".format(self.name, op.name) + ) + op.complete() elif ( isinstance(op, pipeline_ops_base.DisconnectOperation) and not self.pipeline_root.connected ): logger.info( - "{}({}): Transport is disconnected. Completing.".format(self.name, op.name) + "{}({}): Transport is already disconnected. Completing.".format(self.name, op.name) ) - operation_flow.complete_op(stage=self, op=op) + op.complete() elif ( isinstance(op, pipeline_ops_base.DisconnectOperation) or isinstance(op, pipeline_ops_base.ConnectOperation) - or isinstance(op, pipeline_ops_base.ReconnectOperation) + or isinstance(op, pipeline_ops_base.ReauthorizeConnectionOperation) ): self._block(op) - old_callback = op.callback @pipeline_thread.runs_on_pipeline_thread - def on_operation_complete(op): - if op.error: + def on_operation_complete(op, error): + if error: logger.error( "{}({}): op failed. Unblocking queue with error: {}".format( - self.name, op.name, op.error + self.name, op.name, error ) ) else: @@ -369,25 +390,18 @@ class SerializeConnectOpsStage(PipelineStage): "{}({}): op succeeded. Unblocking queue".format(self.name, op.name) ) - op.callback = old_callback - self._unblock(op, op.error) - logger.debug( - "{}({}): unblock is complete. completing op that caused unblock".format( - self.name, op.name - ) - ) - operation_flow.complete_op(stage=self, op=op) + self._unblock(op, error) - op.callback = on_operation_complete - operation_flow.pass_op_to_next_stage(stage=self, op=op) + op.add_callback(on_operation_complete) + self.send_op_down(op) else: - operation_flow.pass_op_to_next_stage(stage=self, op=op) + self.send_op_down(op) @pipeline_thread.runs_on_pipeline_thread def _block(self, op): """ - block this stage while we're waiting for the connect/disconnect/reconnect operation to complete. + block this stage while we're waiting for the connect/disconnect/reauthorize operation to complete. """ logger.debug("{}({}): blocking".format(self.name, op.name)) self.blocked = True @@ -395,7 +409,7 @@ class SerializeConnectOpsStage(PipelineStage): @pipeline_thread.runs_on_pipeline_thread def _unblock(self, op, error): """ - Unblock this stage after the connect/disconnect/reconnect operation is complete. This also means + Unblock this stage after the connect/disconnect/reauthorize operation is complete. This also means releasing all the operations that were queued up. """ logger.debug("{}({}): unblocking and releasing queued ops.".format(self.name, op.name)) @@ -418,21 +432,20 @@ class SerializeConnectOpsStage(PipelineStage): self.name, op.name, op_to_release.name ) ) - op_to_release.error = error - operation_flow.complete_op(self, op_to_release) + op_to_release.complete(error=error) else: logger.debug( "{}({}): releasing {} op.".format(self.name, op.name, op_to_release.name) ) - # call run_op directly here so operations go through this stage again (especiall connect/disconnect ops) + # call run_op directly here so operations go through this stage again (especially connect/disconnect ops) self.run_op(op_to_release) class CoordinateRequestAndResponseStage(PipelineStage): """ - Pipeline stage which is responsible for coordinating SendIotRequestAndWaitForResponseOperation operations. For each - SendIotRequestAndWaitForResponseOperation operation, this stage passes down a SendIotRequestOperation operation and waits for - an IotResponseEvent event. All other events are passed down unmodified. + Pipeline stage which is responsible for coordinating RequestAndResponseOperation operations. For each + RequestAndResponseOperation operation, this stage passes down a RequestOperation operation and waits for + an ResponseEvent event. All other events are passed down unmodified. """ def __init__(self): @@ -440,31 +453,38 @@ class CoordinateRequestAndResponseStage(PipelineStage): self.pending_responses = {} @pipeline_thread.runs_on_pipeline_thread - def _execute_op(self, op): - if isinstance(op, pipeline_ops_base.SendIotRequestAndWaitForResponseOperation): - # Convert SendIotRequestAndWaitForResponseOperation operation into a SendIotRequestOperation operation - # and send it down. A lower level will convert the SendIotRequestOperation into an - # actual protocol client operation. The SendIotRequestAndWaitForResponseOperation operation will be + def _run_op(self, op): + if isinstance(op, pipeline_ops_base.RequestAndResponseOperation): + # Convert RequestAndResponseOperation operation into a RequestOperation operation + # and send it down. A lower level will convert the RequestOperation into an + # actual protocol client operation. The RequestAndResponseOperation operation will be # completed when the corresponding IotResponse event is received in this stage. request_id = str(uuid.uuid4()) + # Alias to avoid overload within the callback below + # CT-TODO: remove the need for this with better callback semantics + op_waiting_for_response = op + @pipeline_thread.runs_on_pipeline_thread - def on_send_request_done(send_request_op): + def on_send_request_done(op, error): logger.debug( "{}({}): Finished sending {} request to {} resource {}".format( - self.name, op.name, op.request_type, op.method, op.resource_location + self.name, + op_waiting_for_response.name, + op_waiting_for_response.request_type, + op_waiting_for_response.method, + op_waiting_for_response.resource_location, ) ) - if send_request_op.error: - op.error = send_request_op.error + if error: logger.debug( "{}({}): removing request {} from pending list".format( - self.name, op.name, request_id + self.name, op_waiting_for_response.name, request_id ) ) del (self.pending_responses[request_id]) - operation_flow.complete_op(self, op) + op_waiting_for_response.complete(error=error) else: # request sent. Nothing to do except wait for the response pass @@ -480,23 +500,24 @@ class CoordinateRequestAndResponseStage(PipelineStage): ) self.pending_responses[request_id] = op - new_op = pipeline_ops_base.SendIotRequestOperation( + new_op = pipeline_ops_base.RequestOperation( method=op.method, resource_location=op.resource_location, request_body=op.request_body, request_id=request_id, request_type=op.request_type, callback=on_send_request_done, + query_params=op.query_params, ) - operation_flow.pass_op_to_next_stage(self, new_op) + self.send_op_down(new_op) else: - operation_flow.pass_op_to_next_stage(self, op) + self.send_op_down(op) @pipeline_thread.runs_on_pipeline_thread def _handle_pipeline_event(self, event): - if isinstance(event, pipeline_events_base.IotResponseEvent): - # match IotResponseEvent events to the saved dictionary of SendIotRequestAndWaitForResponseOperation + if isinstance(event, pipeline_events_base.ResponseEvent): + # match ResponseEvent events to the saved dictionary of RequestAndResponseOperation # operations which have not received responses yet. If the operation is found, # complete it. @@ -510,6 +531,7 @@ class CoordinateRequestAndResponseStage(PipelineStage): del (self.pending_responses[event.request_id]) op.status_code = event.status_code op.response_body = event.response_body + op.retry_after = event.retry_after logger.debug( "{}({}): Completing {} request to {} resource {} with status {}".format( self.name, @@ -520,7 +542,7 @@ class CoordinateRequestAndResponseStage(PipelineStage): op.status_code, ) ) - operation_flow.complete_op(self, op) + op.complete() else: logger.warning( "{}({}): request_id {} not found in pending list. Nothing to do. Dropping".format( @@ -528,4 +550,399 @@ class CoordinateRequestAndResponseStage(PipelineStage): ) ) else: - operation_flow.pass_event_to_previous_stage(self, event) + self.send_event_up(event) + + +class OpTimeoutStage(PipelineStage): + """ + The purpose of the timeout stage is to add timeout errors to select operations + + The timeout_intervals attribute contains a list of operations to track along with + their timeout values. Right now this list is hard-coded but the operations and + intervals will eventually become a parameter. + + For each operation that needs a timeout check, this stage will add a timer to + the operation. If the timer elapses, this stage will fail the operation with + a PipelineTimeoutError. The intention is that a higher stage will know what to + do with that error and act accordingly (either return the error to the user or + retry). + + This stage currently assumes that all timed out operation are just "lost". + It does not attempt to cancel the operation, as Paho doesn't have a way to + cancel an operation, and with QOS=1, sending a pub or sub twice is not + catastrophic. + + Also, as a long-term plan, the operations that need to be watched for timeout + will become an initialization parameter for this stage so that differet + instances of this stage can watch for timeouts on different operations. + This will be done because we want a lower-level timeout stage which can watch + for timeouts at the MQTT level, and we want a higher-level timeout stage which + can watch for timeouts at the iothub level. In this way, an MQTT operation that + times out can be retried as an MQTT operation and a higher-level IoTHub operation + which times out can be retried as an IoTHub operation (which might necessitate + redoing multiple MQTT operations). + """ + + def __init__(self): + super(OpTimeoutStage, self).__init__() + # use a fixed list and fixed intervals for now. Later, this info will come in + # as an init param or a retry poicy + self.timeout_intervals = { + pipeline_ops_mqtt.MQTTSubscribeOperation: 10, + pipeline_ops_mqtt.MQTTUnsubscribeOperation: 10, + } + + @pipeline_thread.runs_on_pipeline_thread + def _run_op(self, op): + if type(op) in self.timeout_intervals: + # Create a timer to watch for operation timeout on this op and attach it + # to the op. + self_weakref = weakref.ref(self) + + @pipeline_thread.invoke_on_pipeline_thread_nowait + def on_timeout(): + this = self_weakref() + logger.info("{}({}): returning timeout error".format(this.name, op.name)) + op.complete( + error=pipeline_exceptions.PipelineTimeoutError( + "operation timed out before protocol client could respond" + ) + ) + + logger.debug("{}({}): Creating timer".format(self.name, op.name)) + op.timeout_timer = threading.Timer(self.timeout_intervals[type(op)], on_timeout) + op.timeout_timer.start() + + # Send the op down, but intercept the return of the op so we can + # remove the timer when the op is done + op.add_callback(self._clear_timer) + logger.debug("{}({}): Sending down".format(self.name, op.name)) + self.send_op_down(op) + else: + self.send_op_down(op) + + @pipeline_thread.runs_on_pipeline_thread + def _clear_timer(self, op, error): + # When an op comes back, delete the timer and pass it right up. + if op.timeout_timer: + logger.debug("{}({}): Cancelling timer".format(self.name, op.name)) + op.timeout_timer.cancel() + op.timeout_timer = None + + +class RetryStage(PipelineStage): + """ + The purpose of the retry stage is to watch specific operations for specific + errors and retry the operations as appropriate. + + Unlike the OpTimeoutStage, this stage will never need to worry about cancelling + failed operations. When an operation is retried at this stage, it is already + considered "failed", so no cancellation needs to be done. + """ + + def __init__(self): + super(RetryStage, self).__init__() + # Retry intervals are hardcoded for now. Later, they come in as an + # init param, probably via retry policy. + self.retry_intervals = { + pipeline_ops_mqtt.MQTTSubscribeOperation: 20, + pipeline_ops_mqtt.MQTTUnsubscribeOperation: 20, + pipeline_ops_mqtt.MQTTPublishOperation: 20, + } + self.ops_waiting_to_retry = [] + + @pipeline_thread.runs_on_pipeline_thread + def _run_op(self, op): + """ + Send all ops down and intercept their return to "watch for retry" + """ + if self._should_watch_for_retry(op): + op.add_callback(self._do_retry_if_necessary) + self.send_op_down(op) + else: + self.send_op_down(op) + + @pipeline_thread.runs_on_pipeline_thread + def _should_watch_for_retry(self, op): + """ + Return True if this op needs to be watched for retry. This can be + called before the op runs. + """ + return type(op) in self.retry_intervals + + @pipeline_thread.runs_on_pipeline_thread + def _should_retry(self, op, error): + """ + Return True if this op needs to be retried. This must be called after + the op completes. + """ + if error: + if self._should_watch_for_retry(op): + if isinstance(error, pipeline_exceptions.PipelineTimeoutError): + return True + return False + + @pipeline_thread.runs_on_pipeline_thread + def _do_retry_if_necessary(self, op, error): + """ + Handler which gets called when operations are complete. This function + is where we check to see if a retry is necessary and set a "retry timer" + which can be used to send the op down again. + """ + if self._should_retry(op, error): + self_weakref = weakref.ref(self) + + @pipeline_thread.invoke_on_pipeline_thread_nowait + def do_retry(): + this = self_weakref() + logger.info("{}({}): retrying".format(this.name, op.name)) + op.retry_timer.cancel() + op.retry_timer = None + this.ops_waiting_to_retry.remove(op) + # Don't just send it down directly. Instead, go through run_op so we get + # retry functionality this time too + this.run_op(op) + + interval = self.retry_intervals[type(op)] + logger.warning( + "{}({}): Op needs retry with interval {} because of {}. Setting timer.".format( + self.name, op.name, interval, error + ) + ) + + # if we don't keep track of this op, it might get collected. + op.halt_completion() + self.ops_waiting_to_retry.append(op) + op.retry_timer = threading.Timer(self.retry_intervals[type(op)], do_retry) + op.retry_timer.start() + + else: + if op.retry_timer: + op.retry_timer.cancel() + op.retry_timer = None + + +transient_connect_errors = [ + pipeline_exceptions.OperationCancelled, + pipeline_exceptions.PipelineTimeoutError, + pipeline_exceptions.OperationError, + transport_exceptions.ConnectionFailedError, + transport_exceptions.ConnectionDroppedError, +] + + +class ReconnectState(object): + """ + Class which holds reconenct states as class variables. Created to make code that reads like an enum without using an enum. + + NEVER_CONNECTED: Ttransport has never been conencted. This state is necessary because some errors might be fatal or transient, + depending on wether the transport has been connceted. For example, a failed conenction is a transient error if we've connected + before, but it's fatal if we've never conencted. + + WAITING_TO_RECONNECT: This stage is in a waiting period before reconnecting. + + CONNECTED_OR_DISCONNECTED: The transport is either connected or disconencted. This stage doesn't really care which one, so + it doesn't keep track. + """ + + NEVER_CONNECTED = "NEVER_CONNECTED" + WAITING_TO_RECONNECT = "WAITING_TO_RECONNECT" + CONNECTED_OR_DISCONNECTED = "CONNECTED_OR_DISCONNECTED" + + +class ReconnectStage(PipelineStage): + def __init__(self): + super(ReconnectStage, self).__init__() + self.reconnect_timer = None + self.state = ReconnectState.NEVER_CONNECTED + # connect delay is hardcoded for now. Later, this comes from a retry policy + self.reconnect_delay = 10 + self.waiting_connect_ops = [] + + @pipeline_thread.runs_on_pipeline_thread + def _run_op(self, op): + if isinstance(op, pipeline_ops_base.ConnectOperation): + if self.state == ReconnectState.WAITING_TO_RECONNECT: + logger.info( + "{}({}): State is {}. Adding to wait list".format( + self.name, op.name, self.state + ) + ) + self.waiting_connect_ops.append(op) + else: + logger.info( + "{}({}): State is {}. Adding to wait list and sending new connect op down".format( + self.name, op.name, self.state + ) + ) + self.waiting_connect_ops.append(op) + self._send_new_connect_op_down() + + elif isinstance(op, pipeline_ops_base.DisconnectOperation): + if self.state == ReconnectState.WAITING_TO_RECONNECT: + logger.info( + "{}({}): State is {}. Canceling waiting ops and sending disconnect down.".format( + self.name, op.name, self.state + ) + ) + self._clear_reconnect_timer() + self._complete_waiting_connect_ops( + pipeline_exceptions.OperationCancelled("Explicit disconnect invoked") + ) + self.state = ReconnectState.CONNECTED_OR_DISCONNECTED + op.complete() + + else: + logger.info( + "{}({}): State is {}. Sending op down.".format(self.name, op.name, self.state) + ) + self.send_op_down(op) + + else: + self.send_op_down(op) + + @pipeline_thread.runs_on_pipeline_thread + def _handle_pipeline_event(self, event): + if isinstance(event, pipeline_events_base.DisconnectedEvent): + if self.pipeline_root.connected: + logger.info( + "{}({}): State is {}. Triggering reconnect timer".format( + self.name, event.name, self.state + ) + ) + self.state = ReconnectState.WAITING_TO_RECONNECT + self._start_reconnect_timer() + else: + logger.info( + "{}({}): State is {}. Doing nothing".format(self.name, event.name, self.state) + ) + + self.send_event_up(event) + + else: + self.send_event_up(event) + + @pipeline_thread.runs_on_pipeline_thread + def _send_new_connect_op_down(self): + self_weakref = weakref.ref(self) + + @pipeline_thread.runs_on_pipeline_thread + def on_connect_complete(op, error): + this = self_weakref() + if this: + if error: + if this.state == ReconnectState.NEVER_CONNECTED: + logger.info( + "{}({}): error on first connection. Not triggering reconnection".format( + this.name, op.name + ) + ) + this._complete_waiting_connect_ops(error) + elif type(error) in transient_connect_errors: + logger.info( + "{}({}): State is {}. Connect failed with transient error. Triggering reconnect timer".format( + self.name, op.name, self.state + ) + ) + self.state = ReconnectState.WAITING_TO_RECONNECT + self._start_reconnect_timer() + + elif this.state == ReconnectState.WAITING_TO_RECONNECT: + logger.info( + "{}({}): non-tranient error. Failing all waiting ops.n".format( + this.name, op.name + ) + ) + self.state = ReconnectState.CONNECTED_OR_DISCONNECTED + self._clear_reconnect_timer() + this._complete_waiting_connect_ops(error) + + else: + logger.info( + "{}({}): State is {}. Connection failed. Not triggering reconnection".format( + this.name, op.name, this.state + ) + ) + this._complete_waiting_connect_ops(error) + else: + logger.info( + "{}({}): State is {}. Connection succeeded".format( + this.name, op.name, this.state + ) + ) + self.state = ReconnectState.CONNECTED_OR_DISCONNECTED + self._clear_reconnect_timer() + self._complete_waiting_connect_ops() + + logger.info("{}: sending new connect op down".format(self.name)) + op = pipeline_ops_base.ConnectOperation(callback=on_connect_complete) + self.send_op_down(op) + + @pipeline_thread.runs_on_pipeline_thread + def _start_reconnect_timer(self): + """ + Set a timer to reconnect after some period of time + """ + logger.info("{}: State is {}. Starting reconnect timer".format(self.name, self.state)) + + self._clear_reconnect_timer() + + self_weakref = weakref.ref(self) + + @pipeline_thread.invoke_on_pipeline_thread_nowait + def on_reconnect_timer_expired(): + this = self_weakref() + this.reconnect_timer = None + if this.state == ReconnectState.WAITING_TO_RECONNECT: + logger.info( + "{}: State is {}. Reconnect timer expired. Sending connect op down".format( + this.name, this.state + ) + ) + this.state = ReconnectState.CONNECTED_OR_DISCONNECTED + this._send_new_connect_op_down() + else: + logger.info( + "{}: State is {}. Reconnect timer expired. Doing nothing".format( + this.name, this.state + ) + ) + + self.reconnect_timer = threading.Timer(self.reconnect_delay, on_reconnect_timer_expired) + self.reconnect_timer.start() + + @pipeline_thread.runs_on_pipeline_thread + def _clear_reconnect_timer(self): + """ + Clear any previous reconnect timer + """ + if self.reconnect_timer: + logger.info("{}: clearing reconnect timer".format(self.name)) + self.reconnect_timer.cancel() + self.reconnect_timer = None + + @pipeline_thread.runs_on_pipeline_thread + def _complete_waiting_connect_ops(self, error=None): + """ + A note of explanation: when we are waiting to reconnect, we need to keep a list of + all connect ops that come through here. We do this for 2 reasons: + + 1. We don't want to pass them down immediately because we want to honor the waiting + period. If we passed them down immediately, we'd try to reconnect immediately + instead of waiting until reconnect_timer fires. + + 2. When we're retrying, there are new ConnectOperation ops sent down regularly. + Any of the ops could be the one that succeeds. When that happens, we need a + way to to complete all of the ops that are patiently waiting for the connection. + + Right now, we only need to do this with ConnectOperation ops because these are the + only ops that need to wait because these are the only ops that cause a connection + to be established. Other ops pass through this stage, and might fail in later + stages, but that's OK. If they needed a connection, the AutoConnectStage before + this stage should be taking care of that. + """ + logger.info("{}: completing waiting ops with error={}".format(self.name, error)) + list_copy = self.waiting_connect_ops + self.waiting_connect_ops = [] + for op in list_copy: + op.complete(error) diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_http.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_http.py new file mode 100644 index 000000000..9cec9cfa5 --- /dev/null +++ b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_http.py @@ -0,0 +1,102 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import logging +import six +import traceback +import copy +from . import ( + pipeline_ops_base, + PipelineStage, + pipeline_ops_http, + pipeline_thread, + pipeline_exceptions, +) +from azure.iot.device.common.http_transport import HTTPTransport +from azure.iot.device.common import handle_exceptions, transport_exceptions +from azure.iot.device.common.callable_weak_method import CallableWeakMethod + +logger = logging.getLogger(__name__) + + +class HTTPTransportStage(PipelineStage): + """ + PipelineStage object which is responsible for interfacing with the HTTP protocol wrapper object. + This stage handles all HTTP operations that are not specific to IoT Hub. + """ + + def __init__(self): + super(HTTPTransportStage, self).__init__() + # The sas_token will be set when Connetion Args are received + self.sas_token = None + + # The transport will be instantiated when Connection Args are received + self.transport = None + + @pipeline_thread.runs_on_pipeline_thread + def _run_op(self, op): + if isinstance(op, pipeline_ops_http.SetHTTPConnectionArgsOperation): + # pipeline_ops_http.SetHTTPConenctionArgsOperation is used to create the HTTPTransport object and set all of it's properties. + logger.debug("{}({}): got connection args".format(self.name, op.name)) + self.sas_token = op.sas_token + self.transport = HTTPTransport( + hostname=op.hostname, + server_verification_cert=op.server_verification_cert, + x509_cert=op.client_cert, + ) + + self.pipeline_root.transport = self.transport + op.complete() + + elif isinstance(op, pipeline_ops_base.UpdateSasTokenOperation): + logger.debug("{}({}): saving sas token and completing".format(self.name, op.name)) + self.sas_token = op.sas_token + op.complete() + + elif isinstance(op, pipeline_ops_http.HTTPRequestAndResponseOperation): + # This will call down to the HTTP Transport with a request and also created a request callback. Because the HTTP Transport will run on the http transport thread, this call should be non-blocking to the pipline thread. + logger.debug( + "{}({}): Generating HTTP request and setting callback before completing.".format( + self.name, op.name + ) + ) + + @pipeline_thread.invoke_on_pipeline_thread_nowait + def on_request_completed(error=None, response=None): + if error: + logger.error( + "{}({}): Error passed to on_request_completed. Error={}".format( + self.name, op.name, error + ) + ) + op.complete(error=error) + else: + logger.debug( + "{}({}): Request completed. Completing op.".format(self.name, op.name) + ) + logger.debug("HTTP Response Status: {}".format(response["status_code"])) + logger.debug("HTTP Response: {}".format(response["resp"].decode("utf-8"))) + op.response_body = response["resp"] + op.status_code = response["status_code"] + op.reason = response["reason"] + op.complete() + + # A deepcopy is necessary here since otherwise the manipulation happening to http_headers will affect the op.headers, which would be an unintended side effect and not a good practice. + http_headers = copy.deepcopy(op.headers) + if self.sas_token: + http_headers["Authorization"] = self.sas_token + + self.transport.request( + method=op.method, + path=op.path, + headers=http_headers, + query_params=op.query_params, + body=op.body, + callback=on_request_completed, + ) + + else: + self.send_op_down(op) diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_mqtt.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_mqtt.py index 566a37cb8..406fdec59 100644 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_mqtt.py +++ b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_mqtt.py @@ -6,16 +6,19 @@ import logging import six +import traceback from . import ( pipeline_ops_base, PipelineStage, pipeline_ops_mqtt, pipeline_events_mqtt, - operation_flow, pipeline_thread, + pipeline_exceptions, + pipeline_events_base, ) from azure.iot.device.common.mqtt_transport import MQTTTransport -from azure.iot.device.common import unhandled_exceptions, errors +from azure.iot.device.common import handle_exceptions, transport_exceptions +from azure.iot.device.common.callable_weak_method import CallableWeakMethod logger = logging.getLogger(__name__) @@ -27,66 +30,83 @@ class MQTTTransportStage(PipelineStage): is not in the MQTT group of operations, but can only be run at the protocol level. """ + def __init__(self): + super(MQTTTransportStage, self).__init__() + + # The sas_token will be set when Connetion Args are received + self.sas_token = None + + # The transport will be instantiated when Connection Args are received + self.transport = None + + self._pending_connection_op = None + @pipeline_thread.runs_on_pipeline_thread def _cancel_pending_connection_op(self): """ - Cancel any running connect, disconnect or reconnect op. Since our ability to "cancel" is fairly limited, + Cancel any running connect, disconnect or reauthorize_connection op. Since our ability to "cancel" is fairly limited, all this does (for now) is to fail the operation """ op = self._pending_connection_op if op: - # TODO: should this actually run a cancel call on the op? - op.error = errors.PipelineError( - "Cancelling because new ConnectOperation, DisconnectOperation, or ReconnectOperation was issued" - ) - operation_flow.complete_op(stage=self, op=op) + # NOTE: This code path should NOT execute in normal flow. There should never already be a pending + # connection op when another is added, due to the SerializeConnectOps stage. + # If this block does execute, there is a bug in the codebase. + error = pipeline_exceptions.OperationCancelled( + "Cancelling because new ConnectOperation, DisconnectOperation, or ReauthorizeConnectionOperation was issued" + ) # TODO: should this actually somehow cancel the operation? + op.complete(error=error) self._pending_connection_op = None @pipeline_thread.runs_on_pipeline_thread - def _execute_op(self, op): + def _run_op(self, op): if isinstance(op, pipeline_ops_mqtt.SetMQTTConnectionArgsOperation): # pipeline_ops_mqtt.SetMQTTConnectionArgsOperation is where we create our MQTTTransport object and set # all of its properties. logger.debug("{}({}): got connection args".format(self.name, op.name)) - self.hostname = op.hostname - self.username = op.username - self.client_id = op.client_id - self.ca_cert = op.ca_cert self.sas_token = op.sas_token - self.client_cert = op.client_cert - self.transport = MQTTTransport( - client_id=self.client_id, - hostname=self.hostname, - username=self.username, - ca_cert=self.ca_cert, - x509_cert=self.client_cert, + client_id=op.client_id, + hostname=op.hostname, + username=op.username, + server_verification_cert=op.server_verification_cert, + x509_cert=op.client_cert, + websockets=self.pipeline_root.pipeline_configuration.websockets, + cipher=self.pipeline_root.pipeline_configuration.cipher, + proxy_options=self.pipeline_root.pipeline_configuration.proxy_options, + ) + self.transport.on_mqtt_connected_handler = CallableWeakMethod( + self, "_on_mqtt_connected" + ) + self.transport.on_mqtt_connection_failure_handler = CallableWeakMethod( + self, "_on_mqtt_connection_failure" + ) + self.transport.on_mqtt_disconnected_handler = CallableWeakMethod( + self, "_on_mqtt_disconnected" + ) + self.transport.on_mqtt_message_received_handler = CallableWeakMethod( + self, "_on_mqtt_message_received" ) - self.transport.on_mqtt_connected_handler = self._on_mqtt_connected - self.transport.on_mqtt_connection_failure_handler = self._on_mqtt_connection_failure - self.transport.on_mqtt_disconnected_handler = self._on_mqtt_disconnected - self.transport.on_mqtt_message_received_handler = self._on_mqtt_message_received - # There can only be one pending connection operation (Connect, Reconnect, Disconnect) + # There can only be one pending connection operation (Connect, ReauthorizeConnection, Disconnect) # at a time. The existing one must be completed or canceled before a new one is set. # Currently, this means that if, say, a connect operation is the pending op and is executed - # but another connection op is begins by the time the CONACK is received, the original - # operation will be cancelled, but the CONACK for it will still be received, and complete the + # but another connection op is begins by the time the CONNACK is received, the original + # operation will be cancelled, but the CONNACK for it will still be received, and complete the # NEW operation. This is not desirable, but it is how things currently work. - # We are however, checking the type, so the CONACK from a cancelled Connect, cannot successfully + # We are however, checking the type, so the CONNACK from a cancelled Connect, cannot successfully # complete a Disconnect operation. self._pending_connection_op = None - self.pipeline_root.transport = self.transport - operation_flow.complete_op(self, op) + op.complete() elif isinstance(op, pipeline_ops_base.UpdateSasTokenOperation): logger.debug("{}({}): saving sas token and completing".format(self.name, op.name)) self.sas_token = op.sas_token - operation_flow.complete_op(self, op) + op.complete() elif isinstance(op, pipeline_ops_base.ConnectOperation): logger.info("{}({}): connecting".format(self.name, op.name)) @@ -96,24 +116,24 @@ class MQTTTransportStage(PipelineStage): try: self.transport.connect(password=self.sas_token) except Exception as e: - logger.error("transport.connect raised error", exc_info=True) + logger.error("transport.connect raised error") + logger.error(traceback.format_exc()) self._pending_connection_op = None - op.error = e - operation_flow.complete_op(self, op) + op.complete(error=e) - elif isinstance(op, pipeline_ops_base.ReconnectOperation): - logger.info("{}({}): reconnecting".format(self.name, op.name)) + elif isinstance(op, pipeline_ops_base.ReauthorizeConnectionOperation): + logger.info("{}({}): reauthorizing".format(self.name, op.name)) - # We set _active_connect_op here because a reconnect is the same as a connect for "active operation" tracking purposes. + # We set _active_connect_op here because reauthorizing the connection is the same as a connect for "active operation" tracking purposes. self._cancel_pending_connection_op() self._pending_connection_op = op try: - self.transport.reconnect(password=self.sas_token) + self.transport.reauthorize_connection(password=self.sas_token) except Exception as e: - logger.error("transport.reconnect raised error", exc_info=True) + logger.error("transport.reauthorize_connection raised error") + logger.error(traceback.format_exc()) self._pending_connection_op = None - op.error = e - operation_flow.complete_op(self, op) + op.complete(error=e) elif isinstance(op, pipeline_ops_base.DisconnectOperation): logger.info("{}({}): disconnecting".format(self.name, op.name)) @@ -123,10 +143,10 @@ class MQTTTransportStage(PipelineStage): try: self.transport.disconnect() except Exception as e: - logger.error("transport.disconnect raised error", exc_info=True) + logger.error("transport.disconnect raised error") + logger.error(traceback.format_exc()) self._pending_connection_op = None - op.error = e - operation_flow.complete_op(self, op) + op.complete(error=e) elif isinstance(op, pipeline_ops_mqtt.MQTTPublishOperation): logger.info("{}({}): publishing on {}".format(self.name, op.name, op.topic)) @@ -134,7 +154,7 @@ class MQTTTransportStage(PipelineStage): @pipeline_thread.invoke_on_pipeline_thread_nowait def on_published(): logger.debug("{}({}): PUBACK received. completing op.".format(self.name, op.name)) - operation_flow.complete_op(self, op) + op.complete() self.transport.publish(topic=op.topic, payload=op.payload, callback=on_published) @@ -144,7 +164,7 @@ class MQTTTransportStage(PipelineStage): @pipeline_thread.invoke_on_pipeline_thread_nowait def on_subscribed(): logger.debug("{}({}): SUBACK received. completing op.".format(self.name, op.name)) - operation_flow.complete_op(self, op) + op.complete() self.transport.subscribe(topic=op.topic, callback=on_subscribed) @@ -156,12 +176,14 @@ class MQTTTransportStage(PipelineStage): logger.debug( "{}({}): UNSUBACK received. completing op.".format(self.name, op.name) ) - operation_flow.complete_op(self, op) + op.complete() self.transport.unsubscribe(topic=op.topic, callback=on_unsubscribed) else: - operation_flow.pass_op_to_next_stage(self, op) + # This code block should not be reached in correct program flow. + # This will raise an error when executed. + self.send_op_down(op) @pipeline_thread.invoke_on_pipeline_thread_nowait def _on_mqtt_message_received(self, topic, payload): @@ -169,9 +191,8 @@ class MQTTTransportStage(PipelineStage): Handler that gets called by the protocol library when an incoming message arrives. Convert that message into a pipeline event and pass it up for someone to handle. """ - operation_flow.pass_event_to_previous_stage( - stage=self, - event=pipeline_events_mqtt.IncomingMQTTMessageEvent(topic=topic, payload=payload), + self.send_event_up( + pipeline_events_mqtt.IncomingMQTTMessageEvent(topic=topic, payload=payload) ) @pipeline_thread.invoke_on_pipeline_thread_nowait @@ -180,22 +201,24 @@ class MQTTTransportStage(PipelineStage): Handler that gets called by the transport when it connects. """ logger.info("_on_mqtt_connected called") - # self.on_connected() tells other pipeline stages that we're connected. Do this before + # Send an event to tell other pipeline stages that we're connected. Do this before # we do anything else (in case upper stages have any "are we connected" logic. - self.on_connected() + self.send_event_up(pipeline_events_base.ConnectedEvent()) if isinstance( self._pending_connection_op, pipeline_ops_base.ConnectOperation - ) or isinstance(self._pending_connection_op, pipeline_ops_base.ReconnectOperation): + ) or isinstance( + self._pending_connection_op, pipeline_ops_base.ReauthorizeConnectionOperation + ): logger.debug("completing connect op") op = self._pending_connection_op self._pending_connection_op = None - operation_flow.complete_op(stage=self, op=op) + op.complete() else: # This should indicate something odd is going on. # If this occurs, either a connect was completed while there was no pending op, # OR that a connect was completed while a disconnect op was pending - logger.warning("Connection was unexpected") + logger.info("Connection was unexpected") @pipeline_thread.invoke_on_pipeline_thread_nowait def _on_mqtt_connection_failure(self, cause): @@ -205,19 +228,22 @@ class MQTTTransportStage(PipelineStage): :param Exception cause: The Exception that caused the connection failure. """ - logger.error("{}: _on_mqtt_connection_failure called: {}".format(self.name, cause)) + logger.info("{}: _on_mqtt_connection_failure called: {}".format(self.name, cause)) if isinstance( self._pending_connection_op, pipeline_ops_base.ConnectOperation - ) or isinstance(self._pending_connection_op, pipeline_ops_base.ReconnectOperation): + ) or isinstance( + self._pending_connection_op, pipeline_ops_base.ReauthorizeConnectionOperation + ): logger.debug("{}: failing connect op".format(self.name)) op = self._pending_connection_op self._pending_connection_op = None - op.error = cause - operation_flow.complete_op(stage=self, op=op) + op.complete(error=cause) else: - logger.warning("{}: Connection failure was unexpected".format(self.name)) - unhandled_exceptions.exception_caught_in_background_thread(cause) + logger.info("{}: Connection failure was unexpected".format(self.name)) + handle_exceptions.swallow_unraised_exception( + cause, log_msg="Unexpected connection failure. Safe to ignore.", log_lvl="info" + ) @pipeline_thread.invoke_on_pipeline_thread_nowait def _on_mqtt_disconnected(self, cause=None): @@ -227,31 +253,47 @@ class MQTTTransportStage(PipelineStage): :param Exception cause: The Exception that caused the disconnection, if any (optional) """ if cause: - logger.error("{}: _on_mqtt_disconnect called: {}".format(self.name, cause)) + logger.info("{}: _on_mqtt_disconnect called: {}".format(self.name, cause)) else: logger.info("{}: _on_mqtt_disconnect called".format(self.name)) - # self.on_disconnected() tells other pipeilne stages that we're disconnected. Do this before - # we do anything else (in case upper stages have any "are we connected" logic. - self.on_disconnected() + # Send an event to tell other pipeilne stages that we're disconnected. Do this before + # we do anything else (in case upper stages have any "are we connected" logic.) + self.send_event_up(pipeline_events_base.DisconnectedEvent()) - if isinstance(self._pending_connection_op, pipeline_ops_base.DisconnectOperation): - logger.debug("{}: completing disconnect op".format(self.name)) + if self._pending_connection_op: + # on_mqtt_disconnected will cause any pending connect op to complete. This is how Paho + # behaves when there is a connection error, and it also makes sense that on_mqtt_disconnected + # would cause a pending connection op to fail. + logger.debug( + "{}: completing pending {} op".format(self.name, self._pending_connection_op.name) + ) op = self._pending_connection_op self._pending_connection_op = None - if cause: - # Only create a ConnnectionDroppedError if there is a cause, - # i.e. unexpected disconnect. - try: - six.raise_from(errors.ConnectionDroppedError, cause) - except errors.ConnectionDroppedError as e: - op.error = e - operation_flow.complete_op(stage=self, op=op) + if isinstance(op, pipeline_ops_base.DisconnectOperation): + # Swallow any errors if we intended to disconnect - even if something went wrong, we + # got to the state we wanted to be in! + if cause: + handle_exceptions.swallow_unraised_exception( + cause, + log_msg="Unexpected disconnect with error while disconnecting - swallowing error", + ) + op.complete() + else: + if cause: + op.complete(error=cause) + else: + op.complete( + error=transport_exceptions.ConnectionDroppedError("transport disconnected") + ) else: - logger.warning("{}: disconnection was unexpected".format(self.name)) - # Regardless of cause, it is now a ConnectionDroppedError - try: - six.raise_from(errors.ConnectionDroppedError, cause) - except errors.ConnectionDroppedError as e: - unhandled_exceptions.exception_caught_in_background_thread(e) + logger.info("{}: disconnection was unexpected".format(self.name)) + # Regardless of cause, it is now a ConnectionDroppedError. log it and swallow it. + # Higher layers will see that we're disconencted and reconnect as necessary. + e = transport_exceptions.ConnectionDroppedError(cause=cause) + handle_exceptions.swallow_unraised_exception( + e, + log_msg="Unexpected disconnection. Safe to ignore since other stages will reconnect.", + log_lvl="info", + ) diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_thread.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_thread.py index aba6af52d..d50611d77 100644 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_thread.py +++ b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_thread.py @@ -9,7 +9,7 @@ import threading import traceback from multiprocessing.pool import ThreadPool from concurrent.futures import ThreadPoolExecutor -from azure.iot.device.common import unhandled_exceptions +from azure.iot.device.common import handle_exceptions logger = logging.getLogger(__name__) @@ -113,7 +113,7 @@ def _invoke_on_executor_thread(func, thread_name, block=True): return func(*args, **kwargs) except Exception as e: if not block: - unhandled_exceptions.exception_caught_in_background_thread(e) + handle_exceptions.handle_background_exception(e) else: raise except BaseException: @@ -166,6 +166,15 @@ def invoke_on_callback_thread_nowait(func): return _invoke_on_executor_thread(func=func, thread_name="callback", block=False) +def invoke_on_http_thread_nowait(func): + """ + Run the decorated function on the callback thread, but don't wait for it to complete + """ + # TODO: Refactor this since this is not in the pipeline thread anymore, so we need to pull this into common. + # Also, the max workers eventually needs to be a bigger number, so that needs to be fixed to allow for more than one HTTP Request a a time. + return _invoke_on_executor_thread(func=func, thread_name="azure_iot_http", block=False) + + def _assert_executor_thread(func, thread_name): """ Decorator which asserts that the given function only gets called inside the given @@ -196,3 +205,10 @@ def runs_on_pipeline_thread(func): Decorator which marks a function as only running inside the pipeline thread. """ return _assert_executor_thread(func=func, thread_name="pipeline") + + +def runs_on_http_thread(func): + """ + Decorator which marks a function as only running inside the http thread. + """ + return _assert_executor_thread(func=func, thread_name="azure_iot_http") diff --git a/azure-iot-device/azure/iot/device/common/transport_exceptions.py b/azure-iot-device/azure/iot/device/common/transport_exceptions.py new file mode 100644 index 000000000..40719471e --- /dev/null +++ b/azure-iot-device/azure/iot/device/common/transport_exceptions.py @@ -0,0 +1,58 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +"""This module defines errors that may be raised from a transport""" + +from .chainable_exception import ChainableException + + +class ConnectionFailedError(ChainableException): + """ + Connection failed to be established + """ + + pass + + +class ConnectionDroppedError(ChainableException): + """ + Previously established connection was dropped + """ + + pass + + +class UnauthorizedError(ChainableException): + """ + Authorization was rejected + """ + + pass + + +class ProtocolClientError(ChainableException): + """ + Error returned from protocol client library + """ + + pass + + +class TlsExchangeAuthError(ChainableException): + """ + Error returned when transport layer exchanges + result in a SSLCertVerification error. + """ + + pass + + +class ProtocolProxyError(ChainableException): + """ + All proxy-related errors. + TODO : Not sure what to name it here. There is a class called Proxy Error already in Pysocks + """ + + pass diff --git a/azure-iot-device/azure/iot/device/constant.py b/azure-iot-device/azure/iot/device/constant.py index ec8b07818..c775ee7c3 100644 --- a/azure-iot-device/azure/iot/device/constant.py +++ b/azure-iot-device/azure/iot/device/constant.py @@ -6,8 +6,10 @@ """This module defines constants for use across the azure-iot-device package """ -VERSION = "2.0.0-preview.10" -USER_AGENT = "py-azure-iot-device/{version}".format(version=VERSION) +VERSION = "2.1.0" +IOTHUB_IDENTIFIER = "azure-iot-device-iothub-py" +PROVISIONING_IDENTIFIER = "azure-iot-device-provisioning-py" IOTHUB_API_VERSION = "2018-06-30" PROVISIONING_API_VERSION = "2019-03-31" SECURITY_MESSAGE_INTERFACE_ID = "urn:azureiot:Security:SecurityAgent:1" +TELEMETRY_MESSAGE_SIZE_LIMIT = 262144 diff --git a/azure-iot-device/azure/iot/device/exceptions.py b/azure-iot-device/azure/iot/device/exceptions.py new file mode 100644 index 000000000..19278390c --- /dev/null +++ b/azure-iot-device/azure/iot/device/exceptions.py @@ -0,0 +1,175 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +"""This module defines an exception surface, exposed as part of the azure.iot.device library API""" + +from azure.iot.device.common.chainable_exception import ChainableException + +# Currently, we are redefining many lower level exceptions in this file, in order to present an API +# surface that will be consistent and unchanging (even though lower level exceptions may change). +# Potentially, this could be somewhat relaxed in the future as the design solidifies. + +# ~~~ EXCEPTIONS ~~~ + + +class OperationCancelled(ChainableException): + """An operation was cancelled""" + + pass + + +# ~~~ CLIENT ERRORS ~~~ + + +class ClientError(ChainableException): + """Generic error for a client""" + + pass + + +class ConnectionFailedError(ClientError): + """Failed to establish a connection""" + + pass + + +class ConnectionDroppedError(ClientError): + """Lost connection while executing operation""" + + pass + + +class CredentialError(ClientError): + """Could not connect client using given credentials""" + + pass + + +# ~~~ SERVICE ERRORS ~~~ + + +class ServiceError(ChainableException): + """Error received from an Azure IoT service""" + + pass + + +# NOTE: These are not (yet) in use. +# Because of this they have been commented out to prevent confusion. + +# class ArgumentError(ServiceError): +# """Service returned 400""" + +# pass + + +# class UnauthorizedError(ServiceError): +# """Service returned 401""" + +# pass + + +# class QuotaExceededError(ServiceError): +# """Service returned 403""" + +# pass + + +# class NotFoundError(ServiceError): +# """Service returned 404""" + +# pass + + +# class DeviceTimeoutError(ServiceError): +# """Service returned 408""" + +# # TODO: is this a method call error? If so, do we retry? +# pass + + +# class DeviceAlreadyExistsError(ServiceError): +# """Service returned 409""" + +# pass + + +# class InvalidEtagError(ServiceError): +# """Service returned 412""" + +# pass + + +# class MessageTooLargeError(ServiceError): +# """Service returned 413""" + +# pass + + +# class ThrottlingError(ServiceError): +# """Service returned 429""" + +# pass + + +# class InternalServiceError(ServiceError): +# """Service returned 500""" + +# pass + + +# class BadDeviceResponseError(ServiceError): +# """Service returned 502""" + +# # TODO: is this a method invoke thing? +# pass + + +# class ServiceUnavailableError(ServiceError): +# """Service returned 503""" + +# pass + + +# class ServiceTimeoutError(ServiceError): +# """Service returned 504""" + +# pass + + +# class FailedStatusCodeError(ServiceError): +# """Service returned unknown status code""" + +# pass + + +# status_code_to_error = { +# 400: ArgumentError, +# 401: UnauthorizedError, +# 403: QuotaExceededError, +# 404: NotFoundError, +# 408: DeviceTimeoutError, +# 409: DeviceAlreadyExistsError, +# 412: InvalidEtagError, +# 413: MessageTooLargeError, +# 429: ThrottlingError, +# 500: InternalServiceError, +# 502: BadDeviceResponseError, +# 503: ServiceUnavailableError, +# 504: ServiceTimeoutError, +# } + + +# def error_from_status_code(status_code, message=None): +# """ +# Return an Error object from a failed status code + +# :param int status_code: Status code returned from failed operation +# :returns: Error object +# """ +# if status_code in status_code_to_error: +# return status_code_to_error[status_code](message) +# else: +# return FailedStatusCodeError(message) diff --git a/azure-iot-device/azure/iot/device/iothub/__init__.py b/azure-iot-device/azure/iot/device/iothub/__init__.py index 6fc8b4045..f81aea8b3 100644 --- a/azure-iot-device/azure/iot/device/iothub/__init__.py +++ b/azure-iot-device/azure/iot/device/iothub/__init__.py @@ -5,7 +5,6 @@ as a Device or Module. """ from .sync_clients import IoTHubDeviceClient, IoTHubModuleClient -from .sync_inbox import InboxEmpty -from .models import Message, MethodResponse +from .models import Message, MethodRequest, MethodResponse -__all__ = ["IoTHubDeviceClient", "IoTHubModuleClient", "Message", "InboxEmpty", "MethodResponse"] +__all__ = ["IoTHubDeviceClient", "IoTHubModuleClient", "Message", "MethodRequest", "MethodResponse"] diff --git a/azure-iot-device/azure/iot/device/iothub/abstract_clients.py b/azure-iot-device/azure/iot/device/iothub/abstract_clients.py index a78522102..4fd86dbeb 100644 --- a/azure-iot-device/azure/iot/device/iothub/abstract_clients.py +++ b/azure-iot-device/azure/iot/device/iothub/abstract_clients.py @@ -14,7 +14,6 @@ import io from . import auth from . import pipeline - logger = logging.getLogger(__name__) # A note on implementation: @@ -23,57 +22,102 @@ logger = logging.getLogger(__name__) # pipeline configuration to be specifically tailored to the method of instantiation. # For instance, .create_from_connection_string and .create_from_edge_envrionment both can use # SymmetricKeyAuthenticationProviders to instantiate pipeline(s), but only .create_from_edge_environment -# should use it to instantiate an EdgePipeline. If the initializer accepted an auth provider, and then +# should use it to instantiate an HTTPPipeline. If the initializer accepted an auth provider, and then # used it to create pipelines, this detail would be lost, as there would be no way to tell if a # SymmetricKeyAuthenticationProvider was intended to be part of an Edge scenario or not. +def _validate_kwargs(**kwargs): + """Helper function to validate user provided kwargs. + Raises TypeError if an invalid option has been provided""" + valid_kwargs = [ + "product_info", + "websockets", + "cipher", + "server_verification_cert", + "proxy_options", + ] + + for kwarg in kwargs: + if kwarg not in valid_kwargs: + raise TypeError("Got an unexpected keyword argument '{}'".format(kwarg)) + + +def _get_pipeline_config_kwargs(**kwargs): + """Helper function to get a subset of user provided kwargs relevant to IoTHubPipelineConfig""" + new_kwargs = {} + if "product_info" in kwargs: + new_kwargs["product_info"] = kwargs["product_info"] + if "websockets" in kwargs: + new_kwargs["websockets"] = kwargs["websockets"] + if "cipher" in kwargs: + new_kwargs["cipher"] = kwargs["cipher"] + if "proxy_options" in kwargs: + new_kwargs["proxy_options"] = kwargs["proxy_options"] + return new_kwargs + + @six.add_metaclass(abc.ABCMeta) class AbstractIoTHubClient(object): - """A superclass representing a generic client. This class needs to be extended for specific clients.""" + """ A superclass representing a generic IoTHub client. + This class needs to be extended for specific clients. + """ - def __init__(self, iothub_pipeline): + def __init__(self, iothub_pipeline, http_pipeline): """Initializer for a generic client. :param iothub_pipeline: The pipeline used to connect to the IoTHub endpoint. - :type iothub_pipeline: IoTHubPipeline + :type iothub_pipeline: :class:`azure.iot.device.iothub.pipeline.IoTHubPipeline` """ self._iothub_pipeline = iothub_pipeline - self._edge_pipeline = None + self._http_pipeline = http_pipeline @classmethod - def create_from_connection_string(cls, connection_string, ca_cert=None): + def create_from_connection_string(cls, connection_string, **kwargs): """ Instantiate the client from a IoTHub device or module connection string. :param str connection_string: The connection string for the IoTHub you wish to connect to. - :param str ca_cert: (OPTIONAL) The trusted certificate chain. Necessary when using a - connection string with a GatewayHostName parameter. + + :param str server_verification_cert: Configuration Option. The trusted certificate chain. + Necessary when using connecting to an endpoint which has a non-standard root of trust, + such as a protocol gateway. + :param bool websockets: Configuration Option. Default is False. Set to true if using MQTT + over websockets. + :param cipher: Configuration Option. Cipher suite(s) for TLS/SSL, as a string in + "OpenSSL cipher list format" or as a list of cipher suite strings. + :type cipher: str or list(str) + :param str product_info: Configuration Option. Default is empty string. The string contains + arbitrary product info which is appended to the user agent string. + :param proxy_options: Options for sending traffic through proxy servers. + :type ProxyOptions: :class:`azure.iot.device.common.proxy_options` :raises: ValueError if given an invalid connection_string. + :raises: TypeError if given an unrecognized parameter. + + :returns: An instance of an IoTHub client that uses a connection string for authentication. """ # TODO: Make this device/module specific and reject non-matching connection strings. # This will require refactoring of the auth package to use common objects (e.g. ConnectionString) # in order to differentiate types of connection strings. + + _validate_kwargs(**kwargs) + + # Pipeline Config setup + pipeline_config_kwargs = _get_pipeline_config_kwargs(**kwargs) + pipeline_configuration = pipeline.IoTHubPipelineConfig(**pipeline_config_kwargs) + if cls.__name__ == "IoTHubDeviceClient": + pipeline_configuration.blob_upload = True + + # Auth Provider setup authentication_provider = auth.SymmetricKeyAuthenticationProvider.parse(connection_string) - authentication_provider.ca_cert = ca_cert # TODO: make this part of the instantiation - iothub_pipeline = pipeline.IoTHubPipeline(authentication_provider) - return cls(iothub_pipeline) + authentication_provider.server_verification_cert = kwargs.get("server_verification_cert") - @classmethod - def create_from_shared_access_signature(cls, sas_token): - """ - Instantiate the client from a Shared Access Signature (SAS) token. + # Pipeline setup + http_pipeline = pipeline.HTTPPipeline(authentication_provider, pipeline_configuration) + iothub_pipeline = pipeline.IoTHubPipeline(authentication_provider, pipeline_configuration) - This method of instantiation is not recommended for general usage. - - :param str sas_token: The string representation of a SAS token. - - :raises: ValueError if given an invalid sas_token - """ - authentication_provider = auth.SharedAccessSignatureAuthenticationProvider.parse(sas_token) - iothub_pipeline = pipeline.IoTHubPipeline(authentication_provider) - return cls(iothub_pipeline) + return cls(iothub_pipeline, http_pipeline) @abc.abstractmethod def connect(self): @@ -107,25 +151,109 @@ class AbstractIoTHubClient(object): def receive_twin_desired_properties_patch(self): pass + @property + def connected(self): + """ + Read-only property to indicate if the transport is connected or not. + """ + return self._iothub_pipeline.connected + @six.add_metaclass(abc.ABCMeta) class AbstractIoTHubDeviceClient(AbstractIoTHubClient): @classmethod - def create_from_x509_certificate(cls, x509, hostname, device_id): + def create_from_x509_certificate(cls, x509, hostname, device_id, **kwargs): """ Instantiate a client which using X509 certificate authentication. - :param hostname: Host running the IotHub. Can be found in the Azure portal in the Overview tab as the string hostname. - :param x509: The complete x509 certificate object, To use the certificate the enrollment object needs to contain cert (either the root certificate or one of the intermediate CA certificates). - If the cert comes from a CER file, it needs to be base64 encoded. - :type x509: X509 - :param device_id: The ID is used to uniquely identify a device in the IoTHub - :return: A IoTHubClient which can use X509 authentication. + + :param str hostname: Host running the IotHub. + Can be found in the Azure portal in the Overview tab as the string hostname. + :param x509: The complete x509 certificate object. + To use the certificate the enrollment object needs to contain cert + (either the root certificate or one of the intermediate CA certificates). + If the cert comes from a CER file, it needs to be base64 encoded. + :type x509: :class:`azure.iot.device.X509` + :param str device_id: The ID used to uniquely identify a device in the IoTHub + + :param str server_verification_cert: Configuration Option. The trusted certificate chain. + Necessary when using connecting to an endpoint which has a non-standard root of trust, + such as a protocol gateway. + :param bool websockets: Configuration Option. Default is False. Set to true if using MQTT + over websockets. + :param cipher: Configuration Option. Cipher suite(s) for TLS/SSL, as a string in + "OpenSSL cipher list format" or as a list of cipher suite strings. + :type cipher: str or list(str) + :param str product_info: Configuration Option. Default is empty string. The string contains + arbitrary product info which is appended to the user agent string. + :param proxy_options: Options for sending traffic through proxy servers. + :type ProxyOptions: :class:`azure.iot.device.common.proxy_options` + + :raises: TypeError if given an unrecognized parameter. + + :returns: An instance of an IoTHub client that uses an X509 certificate for authentication. """ + _validate_kwargs(**kwargs) + + # Pipeline Config setup + pipeline_config_kwargs = _get_pipeline_config_kwargs(**kwargs) + pipeline_configuration = pipeline.IoTHubPipelineConfig(**pipeline_config_kwargs) + pipeline_configuration.blob_upload = True # Blob Upload is a feature on Device Clients + + # Auth Provider setup authentication_provider = auth.X509AuthenticationProvider( x509=x509, hostname=hostname, device_id=device_id ) - iothub_pipeline = pipeline.IoTHubPipeline(authentication_provider) - return cls(iothub_pipeline) + authentication_provider.server_verification_cert = kwargs.get("server_verification_cert") + + # Pipeline setup + http_pipeline = pipeline.HTTPPipeline(authentication_provider, pipeline_configuration) + iothub_pipeline = pipeline.IoTHubPipeline(authentication_provider, pipeline_configuration) + + return cls(iothub_pipeline, http_pipeline) + + @classmethod + def create_from_symmetric_key(cls, symmetric_key, hostname, device_id, **kwargs): + """ + Instantiate a client using symmetric key authentication. + + :param symmetric_key: The symmetric key. + :param str hostname: Host running the IotHub. + Can be found in the Azure portal in the Overview tab as the string hostname. + :param device_id: The device ID + + :param str server_verification_cert: Configuration Option. The trusted certificate chain. + Necessary when using connecting to an endpoint which has a non-standard root of trust, + such as a protocol gateway. + :param bool websockets: Configuration Option. Default is False. Set to true if using MQTT + over websockets. + :param cipher: Configuration Option. Cipher suite(s) for TLS/SSL, as a string in + "OpenSSL cipher list format" or as a list of cipher suite strings. + :type cipher: str or list(str) + :param str product_info: Configuration Option. Default is empty string. The string contains + arbitrary product info which is appended to the user agent string. + + :raises: TypeError if given an unrecognized parameter. + + :return: An instance of an IoTHub client that uses a symmetric key for authentication. + """ + _validate_kwargs(**kwargs) + + # Pipeline Config setup + pipeline_config_kwargs = _get_pipeline_config_kwargs(**kwargs) + pipeline_configuration = pipeline.IoTHubPipelineConfig(**pipeline_config_kwargs) + pipeline_configuration.blob_upload = True # Blob Upload is a feature on Device Clients + + # Auth Provider setup + authentication_provider = auth.SymmetricKeyAuthenticationProvider( + hostname=hostname, device_id=device_id, module_id=None, shared_access_key=symmetric_key + ) + authentication_provider.server_verification_cert = kwargs.get("server_verification_cert") + + # Pipeline setup + http_pipeline = pipeline.HTTPPipeline(authentication_provider, pipeline_configuration) + iothub_pipeline = pipeline.IoTHubPipeline(authentication_provider, pipeline_configuration) + + return cls(iothub_pipeline, http_pipeline) @abc.abstractmethod def receive_message(self): @@ -134,28 +262,42 @@ class AbstractIoTHubDeviceClient(AbstractIoTHubClient): @six.add_metaclass(abc.ABCMeta) class AbstractIoTHubModuleClient(AbstractIoTHubClient): - def __init__(self, iothub_pipeline, edge_pipeline=None): + def __init__(self, iothub_pipeline, http_pipeline): """Initializer for a module client. :param iothub_pipeline: The pipeline used to connect to the IoTHub endpoint. - :type iothub_pipeline: IoTHubPipeline - :param edge_pipeline: (OPTIONAL) The pipeline used to connect to the Edge endpoint. - :type edge_pipeline: EdgePipeline + :type iothub_pipeline: :class:`azure.iot.device.iothub.pipeline.IoTHubPipeline` """ - super(AbstractIoTHubModuleClient, self).__init__(iothub_pipeline) - self._edge_pipeline = edge_pipeline + super(AbstractIoTHubModuleClient, self).__init__(iothub_pipeline, http_pipeline) @classmethod - def create_from_edge_environment(cls): + def create_from_edge_environment(cls, **kwargs): """ Instantiate the client from the IoT Edge environment. This method can only be run from inside an IoT Edge container, or in a debugging environment configured for Edge development (e.g. Visual Studio, Visual Studio Code) - :raises: IoTEdgeError if the IoT Edge container is not configured correctly. - :raises: ValueError if debug variables are invalid + :param bool websockets: Configuration Option. Default is False. Set to true if using MQTT + over websockets. + :param cipher: Configuration Option. Cipher suite(s) for TLS/SSL, as a string in + "OpenSSL cipher list format" or as a list of cipher suite strings. + :type cipher: str or list(str) + :param str product_info: Configuration Option. Default is empty string. The string contains + arbitrary product info which is appended to the user agent string. + + :raises: OSError if the IoT Edge container is not configured correctly. + :raises: ValueError if debug variables are invalid. + + :returns: An instance of an IoTHub client that uses the IoT Edge environment for + authentication. """ + _validate_kwargs(**kwargs) + if kwargs.get("server_verification_cert"): + raise TypeError( + "'server_verification_cert' is not supported by clients using an IoT Edge environment" + ) + # First try the regular Edge container variables try: hostname = os.environ["IOTEDGE_IOTHUBHOSTNAME"] @@ -172,16 +314,16 @@ class AbstractIoTHubModuleClient(AbstractIoTHubClient): try: connection_string = os.environ["EdgeHubConnectionString"] ca_cert_filepath = os.environ["EdgeModuleCACertificateFile"] - except KeyError: - # TODO: consider using a different error here. (OSError?) - raise auth.IoTEdgeError("IoT Edge environment not configured correctly") - - # TODO: variant ca_cert file vs data object that would remove the need for this fopen + except KeyError as e: + new_err = OSError("IoT Edge environment not configured correctly") + new_err.__cause__ = e + raise new_err + # TODO: variant server_verification_cert file vs data object that would remove the need for this fopen # Read the certificate file to pass it on as a string try: with io.open(ca_cert_filepath, mode="r") as ca_cert_file: - ca_cert = ca_cert_file.read() - except (OSError, IOError): + server_verification_cert = ca_cert_file.read() + except (OSError, IOError) as e: # In Python 2, a non-existent file raises IOError, and an invalid file raises an IOError. # In Python 3, a non-existent file raises FileNotFoundError, and an invalid file raises an OSError. # However, FileNotFoundError inherits from OSError, and IOError has been turned into an alias for OSError, @@ -189,44 +331,93 @@ class AbstractIoTHubModuleClient(AbstractIoTHubClient): # Unfortunately, we can't distinguish cause of error from error type, so the raised ValueError has a generic # message. If, in the future, we want to add detail, this could be accomplished by inspecting the e.errno # attribute - raise ValueError("Invalid CA certificate file") + new_err = ValueError("Invalid CA certificate file") + new_err.__cause__ = e + raise new_err # Use Symmetric Key authentication for local dev experience. - authentication_provider = auth.SymmetricKeyAuthenticationProvider.parse( - connection_string - ) - authentication_provider.ca_cert = ca_cert + try: + authentication_provider = auth.SymmetricKeyAuthenticationProvider.parse( + connection_string + ) + except ValueError: + raise + authentication_provider.server_verification_cert = server_verification_cert else: # Use an HSM for authentication in the general case - authentication_provider = auth.IoTEdgeAuthenticationProvider( - hostname=hostname, - device_id=device_id, - module_id=module_id, - gateway_hostname=gateway_hostname, - module_generation_id=module_generation_id, - workload_uri=workload_uri, - api_version=api_version, - ) - iothub_pipeline = pipeline.IoTHubPipeline(authentication_provider) - edge_pipeline = pipeline.EdgePipeline(authentication_provider) - return cls(iothub_pipeline, edge_pipeline=edge_pipeline) + try: + authentication_provider = auth.IoTEdgeAuthenticationProvider( + hostname=hostname, + device_id=device_id, + module_id=module_id, + gateway_hostname=gateway_hostname, + module_generation_id=module_generation_id, + workload_uri=workload_uri, + api_version=api_version, + ) + except auth.IoTEdgeError as e: + new_err = OSError("Unexpected failure in IoTEdge") + new_err.__cause__ = e + raise new_err + + # Pipeline Config setup + pipeline_config_kwargs = _get_pipeline_config_kwargs(**kwargs) + pipeline_configuration = pipeline.IoTHubPipelineConfig(**pipeline_config_kwargs) + pipeline_configuration.method_invoke = ( + True + ) # Method Invoke is allowed on modules created from edge environment + + # Pipeline setup + http_pipeline = pipeline.HTTPPipeline(authentication_provider, pipeline_configuration) + iothub_pipeline = pipeline.IoTHubPipeline(authentication_provider, pipeline_configuration) + + return cls(iothub_pipeline, http_pipeline) @classmethod - def create_from_x509_certificate(cls, x509, hostname, device_id, module_id): + def create_from_x509_certificate(cls, x509, hostname, device_id, module_id, **kwargs): """ Instantiate a client which using X509 certificate authentication. - :param hostname: Host running the IotHub. Can be found in the Azure portal in the Overview tab as the string hostname. - :param x509: The complete x509 certificate object, To use the certificate the enrollment object needs to contain cert (either the root certificate or one of the intermediate CA certificates). - If the cert comes from a CER file, it needs to be base64 encoded. - :type x509: X509 - :param device_id: The ID is used to uniquely identify a device in the IoTHub - :param module_id : The ID of the module to uniquely identify a module on a device on the IoTHub. - :return: A IoTHubClient which can use X509 authentication. + + :param str hostname: Host running the IotHub. + Can be found in the Azure portal in the Overview tab as the string hostname. + :param x509: The complete x509 certificate object. + To use the certificate the enrollment object needs to contain cert + (either the root certificate or one of the intermediate CA certificates). + If the cert comes from a CER file, it needs to be base64 encoded. + :type x509: :class:`azure.iot.device.X509` + :param str device_id: The ID used to uniquely identify a device in the IoTHub + :param str module_id: The ID used to uniquely identify a module on a device on the IoTHub. + + :param str server_verification_cert: Configuration Option. The trusted certificate chain. + Necessary when using connecting to an endpoint which has a non-standard root of trust, + such as a protocol gateway. + :param bool websockets: Configuration Option. Default is False. Set to true if using MQTT + over websockets. + :param cipher: Configuration Option. Cipher suite(s) for TLS/SSL, as a string in + "OpenSSL cipher list format" or as a list of cipher suite strings. + :type cipher: str or list(str) + :param str product_info: Configuration Option. Default is empty string. The string contains + arbitrary product info which is appended to the user agent string. + + :raises: TypeError if given an unrecognized parameter. + + :returns: An instance of an IoTHub client that uses an X509 certificate for authentication. """ + _validate_kwargs(**kwargs) + + # Pipeline Config setup + pipeline_config_kwargs = _get_pipeline_config_kwargs(**kwargs) + pipeline_configuration = pipeline.IoTHubPipelineConfig(**pipeline_config_kwargs) + + # Auth Provider setup authentication_provider = auth.X509AuthenticationProvider( x509=x509, hostname=hostname, device_id=device_id, module_id=module_id ) - iothub_pipeline = pipeline.IoTHubPipeline(authentication_provider) - return cls(iothub_pipeline) + authentication_provider.server_verification_cert = kwargs.get("server_verification_cert") + + # Pipeline setup + http_pipeline = pipeline.HTTPPipeline(authentication_provider, pipeline_configuration) + iothub_pipeline = pipeline.IoTHubPipeline(authentication_provider, pipeline_configuration) + return cls(iothub_pipeline, http_pipeline) @abc.abstractmethod def send_message_to_output(self, message, output_name): diff --git a/azure-iot-device/azure/iot/device/iothub/aio/async_clients.py b/azure-iot-device/azure/iot/device/iothub/aio/async_clients.py index 402083dca..4a652a63f 100644 --- a/azure-iot-device/azure/iot/device/iothub/aio/async_clients.py +++ b/azure-iot-device/azure/iot/device/iothub/aio/async_clients.py @@ -16,12 +16,38 @@ from azure.iot.device.iothub.abstract_clients import ( ) from azure.iot.device.iothub.models import Message from azure.iot.device.iothub.pipeline import constant +from azure.iot.device.iothub.pipeline import exceptions as pipeline_exceptions +from azure.iot.device import exceptions from azure.iot.device.iothub.inbox_manager import InboxManager from .async_inbox import AsyncClientInbox +from azure.iot.device import constant as device_constant logger = logging.getLogger(__name__) +async def handle_result(callback): + try: + return await callback.completion() + except pipeline_exceptions.ConnectionDroppedError as e: + raise exceptions.ConnectionDroppedError(message="Lost connection to IoTHub", cause=e) + except pipeline_exceptions.ConnectionFailedError as e: + raise exceptions.ConnectionFailedError(message="Could not connect to IoTHub", cause=e) + except pipeline_exceptions.UnauthorizedError as e: + raise exceptions.CredentialError(message="Credentials invalid, could not connect", cause=e) + except pipeline_exceptions.ProtocolClientError as e: + raise exceptions.ClientError(message="Error in the IoTHub client", cause=e) + except pipeline_exceptions.TlsExchangeAuthError as e: + raise exceptions.ClientError( + message="Error in the IoTHub client due to TLS exchanges.", cause=e + ) + except pipeline_exceptions.ProtocolProxyError as e: + raise exceptions.ClientError( + message="Error in the IoTHub client raised due to proxy connections.", cause=e + ) + except Exception as e: + raise exceptions.ClientError(message="Unexpected failure", cause=e) + + class GenericIoTHubClient(AbstractIoTHubClient): """A super class representing a generic asynchronous client. This class needs to be extended for specific clients. @@ -33,8 +59,10 @@ class GenericIoTHubClient(AbstractIoTHubClient): This initializer should not be called directly. Instead, use one of the 'create_from_' classmethods to instantiate - TODO: How to document kwargs? - Possible values: iothub_pipeline, edge_pipeline + :param iothub_pipeline: The IoTHubPipeline used for the client + :type iothub_pipeline: :class:`azure.iot.device.iothub.pipeline.IoTHubPipeline` + :param http_pipeline: The HTTPPipeline used for the client + :type http_pipeline: :class:`azure.iot.device.iothub.pipeline.HTTPPipeline` """ # Depending on the subclass calling this __init__, there could be different arguments, # and the super() call could call a different class, due to the different MROs @@ -62,25 +90,37 @@ class GenericIoTHubClient(AbstractIoTHubClient): The destination is chosen based on the credentials passed via the auth_provider parameter that was provided when this object was initialized. + + :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid + and a connection cannot be established. + :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a + connection results in failure. + :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost + during execution. + :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure + during execution. """ logger.info("Connecting to Hub...") connect_async = async_adapter.emulate_async(self._iothub_pipeline.connect) callback = async_adapter.AwaitableCallback() await connect_async(callback=callback) - await callback.completion() + await handle_result(callback) logger.info("Successfully connected to Hub") async def disconnect(self): """Disconnect the client from the Azure IoT Hub or Azure IoT Edge Hub instance. + + :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure + during execution. """ logger.info("Disconnecting from Hub...") disconnect_async = async_adapter.emulate_async(self._iothub_pipeline.disconnect) callback = async_adapter.AwaitableCallback() await disconnect_async(callback=callback) - await callback.completion() + await handle_result(callback) logger.info("Successfully disconnected from Hub") @@ -91,17 +131,31 @@ class GenericIoTHubClient(AbstractIoTHubClient): function will open the connection before sending the event. :param message: The actual message to send. Anything passed that is not an instance of the - Message class will be converted to Message object. + Message class will be converted to Message object. + :type message: :class:`azure.iot.device.Message` or str + + :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid + and a connection cannot be established. + :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a + connection results in failure. + :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost + during execution. + :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure + during execution. + :raises: ValueError if the message fails size validation. """ if not isinstance(message, Message): message = Message(message) + if message.get_size() > device_constant.TELEMETRY_MESSAGE_SIZE_LIMIT: + raise ValueError("Size of telemetry message can not exceed 256 KB.") + logger.info("Sending message to Hub...") send_message_async = async_adapter.emulate_async(self._iothub_pipeline.send_message) callback = async_adapter.AwaitableCallback() await send_message_async(message, callback=callback) - await callback.completion() + await handle_result(callback) logger.info("Successfully sent message to Hub") @@ -111,10 +165,11 @@ class GenericIoTHubClient(AbstractIoTHubClient): If no method request is yet available, will wait until it is available. :param str method_name: Optionally provide the name of the method to receive requests for. - If this parameter is not given, all methods not already being specifically targeted by - a different call to receive_method will be received. + If this parameter is not given, all methods not already being specifically targeted by + a different call to receive_method will be received. :returns: MethodRequest object representing the received method request. + :rtype: `azure.iot.device.MethodRequest` """ if not self._iothub_pipeline.feature_enabled[constant.METHODS]: await self._enable_feature(constant.METHODS) @@ -133,6 +188,16 @@ class GenericIoTHubClient(AbstractIoTHubClient): function will open the connection before sending the event. :param method_response: The MethodResponse to send + :type method_response: :class:`azure.iot.device.MethodResponse` + + :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid + and a connection cannot be established. + :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a + connection results in failure. + :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost + during execution. + :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure + during execution. """ logger.info("Sending method response to Hub...") send_method_response_async = async_adapter.emulate_async( @@ -143,7 +208,7 @@ class GenericIoTHubClient(AbstractIoTHubClient): # TODO: maybe consolidate method_request, result and status into a new object await send_method_response_async(method_response, callback=callback) - await callback.completion() + await handle_result(callback) logger.info("Successfully sent method response to Hub") @@ -151,14 +216,14 @@ class GenericIoTHubClient(AbstractIoTHubClient): """Enable an Azure IoT Hub feature :param feature_name: The name of the feature to enable. - See azure.iot.device.common.pipeline.constant for possible values. + See azure.iot.device.common.pipeline.constant for possible values. """ logger.info("Enabling feature:" + feature_name + "...") enable_feature_async = async_adapter.emulate_async(self._iothub_pipeline.enable_feature) callback = async_adapter.AwaitableCallback() await enable_feature_async(feature_name, callback=callback) - await callback.completion() + await handle_result(callback) logger.info("Successfully enabled feature:" + feature_name) @@ -166,7 +231,17 @@ class GenericIoTHubClient(AbstractIoTHubClient): """ Gets the device or module twin from the Azure IoT Hub or Azure IoT Edge Hub service. - :returns: Twin object which was retrieved from the hub + :returns: Complete Twin as a JSON dict + :rtype: dict + + :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid + and a connection cannot be established. + :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a + connection results in failure. + :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost + during execution. + :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure + during execution. """ logger.info("Getting twin") @@ -177,7 +252,7 @@ class GenericIoTHubClient(AbstractIoTHubClient): callback = async_adapter.AwaitableCallback(return_arg_name="twin") await get_twin_async(callback=callback) - twin = await callback.completion() + twin = await handle_result(callback) logger.info("Successfully retrieved twin") return twin @@ -188,8 +263,17 @@ class GenericIoTHubClient(AbstractIoTHubClient): If the service returns an error on the patch operation, this function will raise the appropriate error. - :param reported_properties_patch: - :type reported_properties_patch: dict, str, int, float, bool, or None (JSON compatible values) + :param reported_properties_patch: Twin Reported Properties patch as a JSON dict + :type reported_properties_patch: dict + + :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid + and a connection cannot be established. + :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a + connection results in failure. + :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost + during execution. + :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure + during execution. """ logger.info("Patching twin reported properties") @@ -202,7 +286,7 @@ class GenericIoTHubClient(AbstractIoTHubClient): callback = async_adapter.AwaitableCallback() await patch_twin_async(patch=reported_properties_patch, callback=callback) - await callback.completion() + await handle_result(callback) logger.info("Successfully sent twin patch") @@ -212,7 +296,8 @@ class GenericIoTHubClient(AbstractIoTHubClient): If no method request is yet available, will wait until it is available. - :returns: desired property patch. This can be dict, str, int, float, bool, or None (JSON compatible values) + :returns: Twin Desired Properties patch as a JSON dict + :rtype: dict """ if not self._iothub_pipeline.feature_enabled[constant.TWIN_PATCHES]: await self._enable_feature(constant.TWIN_PATCHES) @@ -223,6 +308,48 @@ class GenericIoTHubClient(AbstractIoTHubClient): logger.info("twin patch received") return patch + async def get_storage_info_for_blob(self, blob_name): + """Sends a POST request over HTTP to an IoTHub endpoint that will return information for uploading via the Azure Storage Account linked to the IoTHub your device is connected to. + + :param str blob_name: The name in string format of the blob that will be uploaded using the storage API. This name will be used to generate the proper credentials for Storage, and needs to match what will be used with the Azure Storage SDK to perform the blob upload. + + :returns: A JSON-like (dictionary) object from IoT Hub that will contain relevant information including: correlationId, hostName, containerName, blobName, sasToken. + """ + get_storage_info_for_blob_async = async_adapter.emulate_async( + self._http_pipeline.get_storage_info_for_blob + ) + + callback = async_adapter.AwaitableCallback(return_arg_name="storage_info") + await get_storage_info_for_blob_async(blob_name=blob_name, callback=callback) + storage_info = await handle_result(callback) + logger.info("Successfully retrieved storage_info") + return storage_info + + async def notify_blob_upload_status( + self, correlation_id, is_success, status_code, status_description + ): + """When the upload is complete, the device sends a POST request to the IoT Hub endpoint with information on the status of an upload to blob attempt. This is used by IoT Hub to notify listening clients. + + :param str correlation_id: Provided by IoT Hub on get_storage_info_for_blob request. + :param bool is_success: A boolean that indicates whether the file was uploaded successfully. + :param int status_code: A numeric status code that is the status for the upload of the fiel to storage. + :param str status_description: A description that corresponds to the status_code. + """ + notify_blob_upload_status_async = async_adapter.emulate_async( + self._http_pipeline.notify_blob_upload_status + ) + + callback = async_adapter.AwaitableCallback() + await notify_blob_upload_status_async( + correlation_id=correlation_id, + is_success=is_success, + status_code=status_code, + status_description=status_description, + callback=callback, + ) + await handle_result(callback) + logger.info("Successfully notified blob upload status") + class IoTHubDeviceClient(GenericIoTHubClient, AbstractIoTHubDeviceClient): """An asynchronous device client that connects to an Azure IoT Hub instance. @@ -230,16 +357,16 @@ class IoTHubDeviceClient(GenericIoTHubClient, AbstractIoTHubDeviceClient): Intended for usage with Python 3.5.3+ """ - def __init__(self, iothub_pipeline): + def __init__(self, iothub_pipeline, http_pipeline): """Initializer for a IoTHubDeviceClient. This initializer should not be called directly. Instead, use one of the 'create_from_' classmethods to instantiate :param iothub_pipeline: The pipeline used to connect to the IoTHub endpoint. - :type iothub_pipeline: IoTHubPipeline + :type iothub_pipeline: :class:`azure.iot.device.iothub.pipeline.IoTHubPipeline` """ - super().__init__(iothub_pipeline=iothub_pipeline) + super().__init__(iothub_pipeline=iothub_pipeline, http_pipeline=http_pipeline) self._iothub_pipeline.on_c2d_message_received = self._inbox_manager.route_c2d_message async def receive_message(self): @@ -248,6 +375,7 @@ class IoTHubDeviceClient(GenericIoTHubClient, AbstractIoTHubDeviceClient): If no message is yet available, will wait until an item is available. :returns: Message that was sent from the Azure IoT Hub. + :rtype: :class:`azure.iot.device.Message` """ if not self._iothub_pipeline.feature_enabled[constant.C2D_MSG]: await self._enable_feature(constant.C2D_MSG) @@ -265,18 +393,16 @@ class IoTHubModuleClient(GenericIoTHubClient, AbstractIoTHubModuleClient): Intended for usage with Python 3.5.3+ """ - def __init__(self, iothub_pipeline, edge_pipeline=None): + def __init__(self, iothub_pipeline, http_pipeline): """Intializer for a IoTHubModuleClient. This initializer should not be called directly. Instead, use one of the 'create_from_' classmethods to instantiate :param iothub_pipeline: The pipeline used to connect to the IoTHub endpoint. - :type iothub_pipeline: IoTHubPipeline - :param edge_pipeline: (OPTIONAL) The pipeline used to connect to the Edge endpoint. - :type edge_pipeline: EdgePipeline + :type iothub_pipeline: :class:`azure.iot.device.iothub.pipeline.IoTHubPipeline` """ - super().__init__(iothub_pipeline=iothub_pipeline, edge_pipeline=edge_pipeline) + super().__init__(iothub_pipeline=iothub_pipeline, http_pipeline=http_pipeline) self._iothub_pipeline.on_input_message_received = self._inbox_manager.route_input_message async def send_message_to_output(self, message, output_name): @@ -287,13 +413,27 @@ class IoTHubModuleClient(GenericIoTHubClient, AbstractIoTHubModuleClient): If the connection to the service has not previously been opened by a call to connect, this function will open the connection before sending the event. - :param message: message to send to the given output. Anything passed that is not an instance of the - Message class will be converted to Message object. - :param output_name: Name of the output to send the event to. + :param message: Message to send to the given output. Anything passed that is not an + instance of the Message class will be converted to Message object. + :type message: :class:`azure.iot.device.Message` or str + :param str output_name: Name of the output to send the event to. + + :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid + and a connection cannot be established. + :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a + connection results in failure. + :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost + during execution. + :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure + during execution. + :raises: ValueError if the message fails size validation. """ if not isinstance(message, Message): message = Message(message) + if message.get_size() > device_constant.TELEMETRY_MESSAGE_SIZE_LIMIT: + raise ValueError("Size of message can not exceed 256 KB.") + message.output_name = output_name logger.info("Sending message to output:" + output_name + "...") @@ -303,7 +443,7 @@ class IoTHubModuleClient(GenericIoTHubClient, AbstractIoTHubModuleClient): callback = async_adapter.AwaitableCallback() await send_output_event_async(message, callback=callback) - await callback.completion() + await handle_result(callback) logger.info("Successfully sent message to output: " + output_name) @@ -313,7 +453,9 @@ class IoTHubModuleClient(GenericIoTHubClient, AbstractIoTHubModuleClient): If no message is yet available, will wait until an item is available. :param str input_name: The input name to receive a message on. + :returns: Message that was sent to the specified input. + :rtype: :class:`azure.iot.device.Message` """ if not self._iothub_pipeline.feature_enabled[constant.INPUT_MSG]: await self._enable_feature(constant.INPUT_MSG) @@ -323,3 +465,21 @@ class IoTHubModuleClient(GenericIoTHubClient, AbstractIoTHubModuleClient): message = await inbox.get() logger.info("Input message received on: " + input_name) return message + + async def invoke_method(self, method_params, device_id, module_id=None): + """Invoke a method from your client onto a device or module client, and receive the response to the method call. + + :param dict method_params: Should contain a method_name, payload, connect_timeout_in_seconds, response_timeout_in_seconds. + :param str device_id: Device ID of the target device where the method will be invoked. + :param str module_id: Module ID of the target module where the method will be invoked. (Optional) + + :returns: method_result should contain a status, and a payload + :rtype: dict + """ + invoke_method_async = async_adapter.emulate_async(self._http_pipeline.invoke_method) + callback = async_adapter.AwaitableCallback(return_arg_name="invoke_method_response") + await invoke_method_async(device_id, method_params, callback=callback, module_id=module_id) + + method_response = await handle_result(callback) + logger.info("Successfully invoked method") + return method_response diff --git a/azure-iot-device/azure/iot/device/iothub/auth/base_renewable_token_authentication_provider.py b/azure-iot-device/azure/iot/device/iothub/auth/base_renewable_token_authentication_provider.py index a2eb981e1..ba8eaec9a 100644 --- a/azure-iot-device/azure/iot/device/iothub/auth/base_renewable_token_authentication_provider.py +++ b/azure-iot-device/azure/iot/device/iothub/auth/base_renewable_token_authentication_provider.py @@ -10,6 +10,7 @@ import abc import logging import math import six +import weakref from threading import Timer import six.moves.urllib as urllib from .authentication_provider import AuthenticationProvider @@ -60,10 +61,9 @@ class BaseRenewableTokenAuthenticationProvider(AuthenticationProvider): self._token_update_timer = None self.shared_access_key_name = None self.sas_token_str = None - self.on_sas_token_updated_handler = None + self.on_sas_token_updated_handler_list = [] - def disconnect(self): - """Cancel updates to the SAS Token""" + def __del__(self): self._cancel_token_update_timer() def generate_new_sas_token(self): @@ -81,14 +81,14 @@ class BaseRenewableTokenAuthenticationProvider(AuthenticationProvider): If self.token_udpate_callback is set, this callback will be called to notify the pipeline that a new token is available. The pipeline is responsible for doing - whatever is necessary to leverage the new token when the on_sas_token_updated_handler + whatever is necessary to leverage the new token when the on_sas_token_updated_handler_list function is called. The token that is generated expires at some point in the future, based on the token renewal interval and the token renewal margin. When a token is first generated, the authorization provider object will set a timer which will be responsible for renewing the token before the it expires. When this timer fires, it will automatically generate - a new sas token and notify the pipeline by calling self.on_sas_token_updated_handler. + a new sas token and notify the pipeline by calling self.on_sas_token_updated_handler_list. The token update timer is set based on two numbers: self.token_validity_period and self.token_renewal_margin @@ -144,7 +144,11 @@ class BaseRenewableTokenAuthenticationProvider(AuthenticationProvider): t = self._token_update_timer self._token_update_timer = None if t: - logger.debug("Canceling token update timer for (%s,%s)", self.device_id, self.module_id) + logger.debug( + "Canceling token update timer for (%s,%s)", + self.device_id, + self.module_id if self.module_id else "", + ) t.cancel() def _schedule_token_update(self, seconds_until_update): @@ -160,9 +164,30 @@ class BaseRenewableTokenAuthenticationProvider(AuthenticationProvider): seconds_until_update, ) + # It's important to use a weak reference to self inside this timer function + # because we don't want the timer to prevent this object (`self`) from being collected. + # + # We want `self` to get collected when the pipeline gets collected, and + # we want the pipeline to get collected when the client object gets collected. + # This way, everything gets cleaned up when the user is done with the client object, + # as expected. + # + # If timerfunc used `self` directly, that would be a strong reference, and that strong + # reference would prevent `self` from being collected as long as the timer existed. + # + # If this isn't collected when the client is collected, then the object that implements the + # on_sas_token_updated_hndler doesn't get collected. Since that object is part of the + # pipeline, a major part of the pipeline ends up staying around, probably orphaned from + # the client. Since that orphaned part of the pipeline contains Paho, bad things can happen + # if we don't clean up Paho correctly. This is especially noticable if one process + # destroys a client object and creates a new one. + # + self_weakref = weakref.ref(self) + def timerfunc(): - logger.debug("Timed SAS update for (%s,%s)", self.device_id, self.module_id) - self.generate_new_sas_token() + this = self_weakref() + logger.debug("Timed SAS update for (%s,%s)", this.device_id, this.module_id) + this.generate_new_sas_token() self._token_update_timer = Timer(seconds_until_update, timerfunc) self._token_update_timer.daemon = True @@ -173,14 +198,15 @@ class BaseRenewableTokenAuthenticationProvider(AuthenticationProvider): In response to this event, clients should re-initiate their connection in order to use the updated sas token. """ - if self.on_sas_token_updated_handler: + if bool(len(self.on_sas_token_updated_handler_list)): logger.debug( "sending token update notification for (%s, %s)", self.device_id, self.module_id ) - self.on_sas_token_updated_handler() + for x in self.on_sas_token_updated_handler_list: + x() else: logger.warning( - "_notify_token_updated: on_sas_token_updated_handler not set. Doing nothing." + "_notify_token_updated: on_sas_token_updated_handler_list not set. Doing nothing." ) def get_current_sas_token(self): diff --git a/azure-iot-device/azure/iot/device/iothub/auth/iotedge_authentication_provider.py b/azure-iot-device/azure/iot/device/iothub/auth/iotedge_authentication_provider.py index dca540502..9bdd19e29 100644 --- a/azure-iot-device/azure/iot/device/iothub/auth/iotedge_authentication_provider.py +++ b/azure-iot-device/azure/iot/device/iothub/auth/iotedge_authentication_provider.py @@ -12,14 +12,15 @@ import requests import requests_unixsocket import logging from .base_renewable_token_authentication_provider import BaseRenewableTokenAuthenticationProvider -from azure.iot.device import constant +from azure.iot.device.common.chainable_exception import ChainableException +from azure.iot.device.product_info import ProductInfo requests_unixsocket.monkeypatch() logger = logging.getLogger(__name__) -class IoTEdgeError(Exception): +class IoTEdgeError(ChainableException): pass @@ -56,7 +57,7 @@ class IoTEdgeAuthenticationProvider(BaseRenewableTokenAuthenticationProvider): workload_uri=workload_uri, ) self.gateway_hostname = gateway_hostname - self.ca_cert = self.hsm.get_trust_bundle() + self.server_verification_cert = self.hsm.get_trust_bundle() # TODO: reconsider this design when refactoring the BaseRenewableToken auth parent # TODO: Consider handling the quoting within this function, and renaming quoted_resource_uri to resource_uri @@ -107,7 +108,7 @@ class IoTEdgeHsm(object): Return the trust bundle that can be used to validate the server-side SSL TLS connection that we use to talk to edgeHub. - :return: The CA certificate to use for connections to the Azure IoT Edge + :return: The server verification certificate to use for connections to the Azure IoT Edge instance, as a PEM certificate in string form. :raises: IoTEdgeError if unable to retrieve the certificate. @@ -115,23 +116,23 @@ class IoTEdgeHsm(object): r = requests.get( self.workload_uri + "trust-bundle", params={"api-version": self.api_version}, - headers={"User-Agent": urllib.parse.quote_plus(constant.USER_AGENT)}, + headers={"User-Agent": urllib.parse.quote_plus(ProductInfo.get_iothub_user_agent())}, ) # Validate that the request was successful try: r.raise_for_status() - except requests.exceptions.HTTPError: - raise IoTEdgeError("Unable to get trust bundle from EdgeHub") + except requests.exceptions.HTTPError as e: + raise IoTEdgeError(message="Unable to get trust bundle from EdgeHub", cause=e) # Decode the trust bundle try: bundle = r.json() - except ValueError: - raise IoTEdgeError("Unable to decode trust bundle") + except ValueError as e: + raise IoTEdgeError(message="Unable to decode trust bundle", cause=e) # Retrieve the certificate try: cert = bundle["certificate"] - except KeyError: - raise IoTEdgeError("No certificate in trust bundle") + except KeyError as e: + raise IoTEdgeError(message="No certificate in trust bundle", cause=e) return cert def sign(self, data_str): @@ -161,21 +162,21 @@ class IoTEdgeHsm(object): r = requests.post( # TODO: can we use json field instead of data? url=path, params={"api-version": self.api_version}, - headers={"User-Agent": urllib.parse.quote_plus(constant.USER_AGENT)}, + headers={"User-Agent": urllib.parse.quote_plus(ProductInfo.get_iothub_user_agent())}, data=json.dumps(sign_request), ) try: r.raise_for_status() - except requests.exceptions.HTTPError: - raise IoTEdgeError("Unable to sign data") + except requests.exceptions.HTTPError as e: + raise IoTEdgeError(message="Unable to sign data", cause=e) try: sign_response = r.json() - except ValueError: - raise IoTEdgeError("Unable to decode signed data") + except ValueError as e: + raise IoTEdgeError(message="Unable to decode signed data", cause=e) try: signed_data_str = sign_response["digest"] - except KeyError: - raise IoTEdgeError("No signed data received") + except KeyError as e: + raise IoTEdgeError(message="No signed data received", cause=e) return urllib.parse.quote(signed_data_str) diff --git a/azure-iot-device/azure/iot/device/iothub/auth/sk_authentication_provider.py b/azure-iot-device/azure/iot/device/iothub/auth/sk_authentication_provider.py index 2355a451f..d6d309c36 100644 --- a/azure-iot-device/azure/iot/device/iothub/auth/sk_authentication_provider.py +++ b/azure-iot-device/azure/iot/device/iothub/auth/sk_authentication_provider.py @@ -64,7 +64,7 @@ class SymmetricKeyAuthenticationProvider(BaseRenewableTokenAuthenticationProvide self.shared_access_key = shared_access_key self.shared_access_key_name = shared_access_key_name self.gateway_hostname = gateway_hostname - self.ca_cert = None + self.server_verification_cert = None @staticmethod def parse(connection_string): diff --git a/azure-iot-device/azure/iot/device/iothub/models/message.py b/azure-iot-device/azure/iot/device/iothub/models/message.py index 564c11792..350181e3c 100644 --- a/azure-iot-device/azure/iot/device/iothub/models/message.py +++ b/azure-iot-device/azure/iot/device/iothub/models/message.py @@ -6,6 +6,7 @@ """This module contains a class representing messages that are sent or received. """ from azure.iot.device import constant +import sys # TODO: Revise this class. Does all of this REALLY need to be here? @@ -15,7 +16,7 @@ class Message(object): :ivar data: The data that constitutes the payload :ivar custom_properties: Dictionary of custom message properties :ivar lock_token: Used by receiver to abandon, reject or complete the message - :ivar message id: A user-settlable identifier for the message used for request-reply patterns. Format: A case-sensitive string (up to 128 characters long) of ASCII 7-bit alphanumeric characters + {'-', ':', '.', '+', '%', '_', '#', '*', '?', '!', '(', ')', ',', '=', '@', ';', '$', '''} + :ivar message id: A user-settable identifier for the message used for request-reply patterns. Format: A case-sensitive string (up to 128 characters long) of ASCII 7-bit alphanumeric characters + {'-', ':', '.', '+', '%', '_', '#', '*', '?', '!', '(', ')', ',', '=', '@', ';', '$', '''} :ivar sequence_number: A number (unique per device-queue) assigned by IoT Hub to each message :ivar to: A destination specified for Cloud-to-Device (C2D) messages :ivar expiry_time_utc: Date and time of message expiration in UTC format @@ -36,8 +37,8 @@ class Message(object): :param data: The data that constitutes the payload :param str message_id: A user-settable identifier for the message used for request-reply patterns. Format: A case-sensitive string (up to 128 characters long) of ASCII 7-bit alphanumeric characters + {'-', ':', '.', '+', '%', '_', '#', '*', '?', '!', '(', ')', ',', '=', '@', ';', '$', '''} - :param str content_encoding: Content encoding of the message data. Can be 'utf-8', 'utf-16' or 'utf-32' - :param str content_type: Content type property used to routes with the message body. Can be 'application/json' + :param str content_encoding: Content encoding of the message data. Other values can be utf-16' or 'utf-32' + :param str content_type: Content type property used to routes with the message body. :param str output_name: Name of the output that the is being sent to. """ self.data = data @@ -70,3 +71,16 @@ class Message(object): def __str__(self): return str(self.data) + + def get_size(self): + total = 0 + total = total + sum( + sys.getsizeof(v) + for v in self.__dict__.values() + if v is not None and v is not self.custom_properties + ) + if self.custom_properties: + total = total + sum( + sys.getsizeof(v) for v in self.custom_properties.values() if v is not None + ) + return total diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/__init__.py b/azure-iot-device/azure/iot/device/iothub/pipeline/__init__.py index e2d0ff6ee..27583f849 100644 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/__init__.py +++ b/azure-iot-device/azure/iot/device/iothub/pipeline/__init__.py @@ -6,4 +6,5 @@ INTERNAL USAGE ONLY """ from .iothub_pipeline import IoTHubPipeline -from .edge_pipeline import EdgePipeline +from .http_pipeline import HTTPPipeline +from .config import IoTHubPipelineConfig diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/config.py b/azure-iot-device/azure/iot/device/iothub/pipeline/config.py new file mode 100644 index 000000000..907a9f622 --- /dev/null +++ b/azure-iot-device/azure/iot/device/iothub/pipeline/config.py @@ -0,0 +1,30 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import logging +from azure.iot.device.common.pipeline.config import BasePipelineConfig + +logger = logging.getLogger(__name__) + + +class IoTHubPipelineConfig(BasePipelineConfig): + """A class for storing all configurations/options for IoTHub clients in the Azure IoT Python Device Client Library. + """ + + def __init__(self, product_info="", **kwargs): + """Initializer for IoTHubPipelineConfig which passes all unrecognized keyword-args down to BasePipelineConfig + to be evaluated. This stacked options setting is to allow for unique configuration options to exist between the + IoTHub Client and the Provisioning Client, while maintaining a base configuration class with shared config options. + + :param str product_info: A custom identification string for the type of device connecting to Azure IoT Hub. + """ + super(IoTHubPipelineConfig, self).__init__(**kwargs) + self.product_info = product_info + + # Now, the parameters below are not exposed to the user via kwargs. They need to be set by manipulating the IoTHubPipelineConfig object. + # They are not in the BasePipelineConfig because these do not apply to the provisioning client. + self.blob_upload = False + self.method_invoke = False diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/exceptions.py b/azure-iot-device/azure/iot/device/iothub/pipeline/exceptions.py new file mode 100644 index 000000000..e11682c06 --- /dev/null +++ b/azure-iot-device/azure/iot/device/iothub/pipeline/exceptions.py @@ -0,0 +1,22 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +"""This module defines an exception surface, exposed as part of the pipeline API""" + +# For now, present relevant transport errors as part of the Pipeline API surface +# so that they do not have to be duplicated at this layer. +from azure.iot.device.common.pipeline.pipeline_exceptions import * +from azure.iot.device.common.transport_exceptions import ( + ConnectionFailedError, + ConnectionDroppedError, + # TODO: UnauthorizedError (the one from transport) should probably not surface out of + # the pipeline due to confusion with the higher level service UnauthorizedError. It + # should probably get turned into some other error instead (e.g. ConnectionFailedError). + # But for now, this is a stopgap. + UnauthorizedError, + ProtocolClientError, + TlsExchangeAuthError, + ProtocolProxyError, +) diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/http_map_error.py b/azure-iot-device/azure/iot/device/iothub/pipeline/http_map_error.py new file mode 100644 index 000000000..f67e338a2 --- /dev/null +++ b/azure-iot-device/azure/iot/device/iothub/pipeline/http_map_error.py @@ -0,0 +1,67 @@ +def translate_error(sc, reason): + """ + Codes_SRS_NODE_IOTHUB_REST_API_CLIENT_16_012: [Any error object returned by translate_error shall inherit from the generic Error Javascript object and have 3 properties: + - response shall contain the IncomingMessage object returned by the HTTP layer. + - reponseBody shall contain the content of the HTTP response. + - message shall contain a human-readable error message.] + """ + message = "Error: {}".format(reason) + if sc == 400: + # translate_error shall return an ArgumentError if the HTTP response status code is 400. + error = "ArgumentError({})".format(message) + + elif sc == 401: + # translate_error shall return an UnauthorizedError if the HTTP response status code is 401. + error = "UnauthorizedError({})".format(message) + + elif sc == 403: + # translate_error shall return an TooManyDevicesError if the HTTP response status code is 403. + error = "TooManyDevicesError({})".format(message) + + elif sc == 404: + if reason == "Device Not Found": + # translate_error shall return an DeviceNotFoundError if the HTTP response status code is 404 and if the error code within the body of the error response is DeviceNotFound. + error = "DeviceNotFoundError({})".format(message) + elif reason == "IoTHub Not Found": + # translate_error shall return an IotHubNotFoundError if the HTTP response status code is 404 and if the error code within the body of the error response is IotHubNotFound. + error = "IotHubNotFoundError({})".format(message) + else: + error = "Error('Not found')" + + elif sc == 408: + # translate_error shall return a DeviceTimeoutError if the HTTP response status code is 408. + error = "DeviceTimeoutError({})".format(message) + + elif sc == 409: + # translate_error shall return an DeviceAlreadyExistsError if the HTTP response status code is 409. + error = "DeviceAlreadyExistsError({})".format(message) + + elif sc == 412: + # translate_error shall return an InvalidEtagError if the HTTP response status code is 412. + error = "InvalidEtagError({})".format(message) + + elif sc == 429: + # translate_error shall return an ThrottlingError if the HTTP response status code is 429.] + error = "ThrottlingError({})".format(message) + + elif sc == 500: + # translate_error shall return an InternalServerError if the HTTP response status code is 500. + error = "InternalServerError({})".format(message) + + elif sc == 502: + # translate_error shall return a BadDeviceResponseError if the HTTP response status code is 502. + error = "BadDeviceResponseError({})".format(message) + + elif sc == 503: + # translate_error shall return an ServiceUnavailableError if the HTTP response status code is 503. + error = "ServiceUnavailableError({})".format(message) + + elif sc == 504: + # translate_error shall return a GatewayTimeoutError if the HTTP response status code is 504. + error = "GatewayTimeoutError({})".format(message) + + else: + # If the HTTP error code is unknown, translate_error should return a generic Javascript Error object. + error = "Error({})".format(message) + + return error diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/http_path_iothub.py b/azure-iot-device/azure/iot/device/iothub/pipeline/http_path_iothub.py new file mode 100644 index 000000000..ef3e711bf --- /dev/null +++ b/azure-iot-device/azure/iot/device/iothub/pipeline/http_path_iothub.py @@ -0,0 +1,44 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import logging +import six.moves.urllib as urllib + +logger = logging.getLogger(__name__) + + +def get_method_invoke_path(device_id, module_id=None): + """ + :return: The path for invoking methods from one module to a device or module. It is of the format + twins/uri_encode($device_id)/modules/uri_encode($module_id)/methods + """ + if module_id: + return "twins/{device_id}/modules/{module_id}/methods".format( + device_id=urllib.parse.quote_plus(device_id), + module_id=urllib.parse.quote_plus(module_id), + ) + else: + return "twins/{device_id}/methods".format(device_id=urllib.parse.quote_plus(device_id)) + + +def get_storage_info_for_blob_path(device_id): + """ + This does not take a module_id since get_storage_info_for_blob_path should only ever be invoked on device clients. + + :return: The path for getting the storage sdk credential information from IoT Hub. It is of the format + devices/uri_encode($device_id)/files + """ + return "devices/{}/files".format(urllib.parse.quote_plus(device_id)) + + +def get_notify_blob_upload_status_path(device_id): + """ + This does not take a module_id since get_notify_blob_upload_status_path should only ever be invoked on device clients. + + :return: The path for getting the storage sdk credential information from IoT Hub. It is of the format + devices/uri_encode($device_id)/files/notifications + """ + return "devices/{}/files/notifications".format(urllib.parse.quote_plus(device_id)) diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/http_pipeline.py b/azure-iot-device/azure/iot/device/iothub/pipeline/http_pipeline.py new file mode 100644 index 000000000..235219ba1 --- /dev/null +++ b/azure-iot-device/azure/iot/device/iothub/pipeline/http_pipeline.py @@ -0,0 +1,170 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import logging +import sys +from azure.iot.device.common.evented_callback import EventedCallback +from azure.iot.device.common.pipeline import ( + pipeline_stages_base, + pipeline_ops_base, + pipeline_stages_http, +) + +from azure.iot.device.iothub.pipeline import exceptions as pipeline_exceptions + +from . import ( + constant, + pipeline_stages_iothub, + pipeline_ops_iothub, + pipeline_ops_iothub_http, + pipeline_stages_iothub_http, +) +from azure.iot.device.iothub.auth.x509_authentication_provider import X509AuthenticationProvider + +logger = logging.getLogger(__name__) + + +class HTTPPipeline(object): + """Pipeline to communicate with Edge. + Uses HTTP. + """ + + def __init__(self, auth_provider, pipeline_configuration): + """ + Constructor for instantiating a pipeline adapter object. + + :param auth_provider: The authentication provider + :param pipeline_configuration: The configuration generated based on user inputs + """ + self._pipeline = ( + pipeline_stages_base.PipelineRootStage(pipeline_configuration=pipeline_configuration) + .append_stage(pipeline_stages_iothub.UseAuthProviderStage()) + .append_stage(pipeline_stages_iothub_http.IoTHubHTTPTranslationStage()) + .append_stage(pipeline_stages_http.HTTPTransportStage()) + ) + + callback = EventedCallback() + + if isinstance(auth_provider, X509AuthenticationProvider): + op = pipeline_ops_iothub.SetX509AuthProviderOperation( + auth_provider=auth_provider, callback=callback + ) + else: # Currently everything else goes via this block. + op = pipeline_ops_iothub.SetAuthProviderOperation( + auth_provider=auth_provider, callback=callback + ) + + self._pipeline.run_op(op) + callback.wait_for_completion() + + def invoke_method(self, device_id, method_params, callback, module_id=None): + """ + Send a request to the service to invoke a method on a target device or module. + + :param device_id: The target device id + :param method_params: The method parameters to be invoked on the target client + :param callback: callback which is called when request has been fulfilled. + On success, this callback is called with the error=None. + On failure, this callback is called with error set to the cause of the failure. + :param module_id: The target module id + + The following exceptions are not "raised", but rather returned via the "error" parameter + when invoking "callback": + + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` + """ + logger.debug("IoTHubPipeline invoke_method called") + if not self._pipeline.pipeline_configuration.method_invoke: + # If this parameter is not set, that means that the pipeline was not generated by the edge environment. Method invoke only works for clients generated using the edge environment. + error = pipeline_exceptions.PipelineError( + "invoke_method called, but it is only supported on module clients generated from an edge environment. If you are not using a module generated from an edge environment, you cannot use invoke_method" + ) + return callback(error=error) + + def on_complete(op, error): + callback(error=error, invoke_method_response=op.method_response) + + self._pipeline.run_op( + pipeline_ops_iothub_http.MethodInvokeOperation( + target_device_id=device_id, + target_module_id=module_id, + method_params=method_params, + callback=on_complete, + ) + ) + + def get_storage_info_for_blob(self, blob_name, callback): + """ + Sends a POST request to the IoT Hub service endpoint to retrieve an object that contains information for uploading via the Storage SDK. + + :param blob_name: The name of the blob that will be uploaded via the Azure Storage SDK. + :param callback: callback which is called when request has been fulfilled. + On success, this callback is called with the error=None, and the storage_info set to the information JSON received from the service. + On failure, this callback is called with error set to the cause of the failure, and the storage_info=None. + + The following exceptions are not "raised", but rather returned via the "error" parameter + when invoking "callback": + + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` + """ + logger.debug("IoTHubPipeline get_storage_info_for_blob called") + if not self._pipeline.pipeline_configuration.blob_upload: + # If this parameter is not set, that means this is not a device client. Upload to blob is not supported on module clients. + error = pipeline_exceptions.PipelineError( + "get_storage_info_for_blob called, but it is only supported for use with device clients. Ensure you are using a device client." + ) + return callback(error=error) + + def on_complete(op, error): + callback(error=error, storage_info=op.storage_info) + + self._pipeline.run_op( + pipeline_ops_iothub_http.GetStorageInfoOperation( + blob_name=blob_name, callback=on_complete + ) + ) + + def notify_blob_upload_status( + self, correlation_id, is_success, status_code, status_description, callback + ): + """ + Sends a POST request to a IoT Hub service endpoint to notify the status of the Storage SDK call for a blob upload. + + :param str correlation_id: Provided by IoT Hub on get_storage_info_for_blob request. + :param bool is_success: A boolean that indicates whether the file was uploaded successfully. + :param int status_code: A numeric status code that is the status for the upload of the fiel to storage. + :param str status_description: A description that corresponds to the status_code. + + :param callback: callback which is called when request has been fulfilled. + On success, this callback is called with the error=None. + On failure, this callback is called with error set to the cause of the failure. + + + The following exceptions are not "raised", but rather returned via the "error" parameter + when invoking "callback": + + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` + """ + logger.debug("IoTHubPipeline notify_blob_upload_status called") + if not self._pipeline.pipeline_configuration.blob_upload: + # If this parameter is not set, that means this is not a device client. Upload to blob is not supported on module clients. + error = pipeline_exceptions.PipelineError( + "notify_blob_upload_status called, but it is only supported for use with device clients. Ensure you are using a device client." + ) + return callback(error=error) + + def on_complete(op, error): + callback(error=error) + + self._pipeline.run_op( + pipeline_ops_iothub_http.NotifyBlobUploadStatusOperation( + correlation_id=correlation_id, + is_success=is_success, + status_code=status_code, + status_description=status_description, + callback=on_complete, + ) + ) diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/iothub_pipeline.py b/azure-iot-device/azure/iot/device/iothub/pipeline/iothub_pipeline.py index 789262254..b89092b98 100644 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/iothub_pipeline.py +++ b/azure-iot-device/azure/iot/device/iothub/pipeline/iothub_pipeline.py @@ -25,11 +25,13 @@ logger = logging.getLogger(__name__) class IoTHubPipeline(object): - def __init__(self, auth_provider): + def __init__(self, auth_provider, pipeline_configuration): """ Constructor for instantiating a pipeline adapter object :param auth_provider: The authentication provider + :param pipeline_configuration: The configuration generated based on user inputs """ + self.feature_enabled = { constant.C2D_MSG: False, constant.INPUT_MSG: False, @@ -46,14 +48,70 @@ class IoTHubPipeline(object): self.on_method_request_received = None self.on_twin_patch_received = None + # Currently a single timeout stage and a single retry stage for MQTT retry only. + # Later, a higher level timeout and a higher level retry stage. self._pipeline = ( - pipeline_stages_base.PipelineRootStage() + # + # The root is always the root. By definition, it's the first stage in the pipeline. + # + pipeline_stages_base.PipelineRootStage(pipeline_configuration=pipeline_configuration) + # + # UseAuthProviderStage comes near the root by default because it doesn't need to be after + # anything, but it does need to be before IoTHubMQTTTranslationStage. + # .append_stage(pipeline_stages_iothub.UseAuthProviderStage()) - .append_stage(pipeline_stages_iothub.HandleTwinOperationsStage()) + # + # TwinRequestResponseStage comes near the root by default because it doesn't need to be + # after anything + # + .append_stage(pipeline_stages_iothub.TwinRequestResponseStage()) + # + # CoordinateRequestAndResponseStage needs to be after TwinRequestResponseStage because + # TwinRequestResponseStage creates the request ops that CoordinateRequestAndResponseStage + # is coordinating. It needs to be before IoTHubMQTTTranslationStage because that stage + # operates on ops that CoordinateRequestAndResponseStage produces + # .append_stage(pipeline_stages_base.CoordinateRequestAndResponseStage()) - .append_stage(pipeline_stages_iothub_mqtt.IoTHubMQTTConverterStage()) - .append_stage(pipeline_stages_base.EnsureConnectionStage()) - .append_stage(pipeline_stages_base.SerializeConnectOpsStage()) + # + # IoTHubMQTTTranslationStage comes here because this is the point where we can translate + # all operations directly into MQTT. After this stage, only pipeline_stages_base stages + # are allowed because IoTHubMQTTTranslationStage removes all the IoTHub-ness from the ops + # + .append_stage(pipeline_stages_iothub_mqtt.IoTHubMQTTTranslationStage()) + # + # AutoConnectStage comes here because only MQTT ops have the need_connection flag set + # and this is the first place in the pipeline wherer we can guaranetee that all network + # ops are MQTT ops. + # + .append_stage(pipeline_stages_base.AutoConnectStage()) + # + # ReconnectStage needs to be after AutoConnectStage because ReconnectStage sets/clears + # the virtually_conencted flag and we want an automatic connection op to set this flag so + # we can reconnect autoconnect operations. This is important, for example, if a + # send_message causes the transport to automatically connect, but that connection fails. + # When that happens, the ReconenctState will hold onto the ConnectOperation until it + # succeeds, and only then will return success to the AutoConnectStage which will + # allow the publish to continue. + # + .append_stage(pipeline_stages_base.ReconnectStage()) + # + # ConnectionLockStage needs to be after ReconnectStage because we want any ops that + # ReconnectStage creates to go through the ConnectionLockStage gate + # + .append_stage(pipeline_stages_base.ConnectionLockStage()) + # + # RetryStage needs to be near the end because it's retrying low-level MQTT operations. + # + .append_stage(pipeline_stages_base.RetryStage()) + # + # OpTimeoutStage needs to be after RetryStage because OpTimeoutStage returns the timeout + # errors that RetryStage is watching for. + # + .append_stage(pipeline_stages_base.OpTimeoutStage()) + # + # MQTTTransportStage needs to be at the very end of the pipeline because this is where + # operations turn into network traffic + # .append_stage(pipeline_stages_mqtt.MQTTTransportStage()) ) @@ -110,23 +168,25 @@ class IoTHubPipeline(object): self._pipeline.run_op(op) callback.wait_for_completion() - if op.error: - logger.error("{} failed: {}".format(op.name, op.error)) - raise op.error def connect(self, callback): """ Connect to the service. :param callback: callback which is called when the connection to the service is complete. + + The following exceptions are not "raised", but rather returned via the "error" parameter + when invoking "callback": + + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionFailedError` + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionDroppedError` + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.UnauthorizedError` + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` """ logger.debug("Starting ConnectOperation on the pipeline") - def on_complete(op): - if op.error: - callback(error=op.error) - else: - callback() + def on_complete(op, error): + callback(error=error) self._pipeline.run_op(pipeline_ops_base.ConnectOperation(callback=on_complete)) @@ -135,14 +195,16 @@ class IoTHubPipeline(object): Disconnect from the service. :param callback: callback which is called when the connection to the service has been disconnected + + The following exceptions are not "raised", but rather returned via the "error" parameter + when invoking "callback": + + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` """ logger.debug("Starting DisconnectOperation on the pipeline") - def on_complete(op): - if op.error: - callback(error=op.error) - else: - callback() + def on_complete(op, error): + callback(error=error) self._pipeline.run_op(pipeline_ops_base.DisconnectOperation(callback=on_complete)) @@ -152,13 +214,18 @@ class IoTHubPipeline(object): :param message: message to send. :param callback: callback which is called when the message publish has been acknowledged by the service. + + The following exceptions are not "raised", but rather returned via the "error" parameter + when invoking "callback": + + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionFailedError` + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionDroppedError` + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.UnauthorizedError` + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` """ - def on_complete(op): - if op.error: - callback(error=op.error) - else: - callback() + def on_complete(op, error): + callback(error=error) self._pipeline.run_op( pipeline_ops_iothub.SendD2CMessageOperation(message=message, callback=on_complete) @@ -170,13 +237,18 @@ class IoTHubPipeline(object): :param message: message to send. :param callback: callback which is called when the message publish has been acknowledged by the service. + + The following exceptions are not "raised", but rather returned via the "error" parameter + when invoking "callback": + + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionFailedError` + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionDroppedError` + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.UnauthorizedError` + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` """ - def on_complete(op): - if op.error: - callback(error=op.error) - else: - callback() + def on_complete(op, error): + callback(error=error) self._pipeline.run_op( pipeline_ops_iothub.SendOutputEventOperation(message=message, callback=on_complete) @@ -188,14 +260,19 @@ class IoTHubPipeline(object): :param method_response: the method response to send :param callback: callback which is called when response has been acknowledged by the service + + The following exceptions are not "raised", but rather returned via the "error" parameter + when invoking "callback": + + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionFailedError` + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionDroppedError` + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.UnauthorizedError` + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` """ logger.debug("IoTHubPipeline send_method_response called") - def on_complete(op): - if op.error: - callback(error=op.error) - else: - callback() + def on_complete(op, error): + callback(error=error) self._pipeline.run_op( pipeline_ops_iothub.SendMethodResponseOperation( @@ -208,12 +285,22 @@ class IoTHubPipeline(object): Send a request for a full twin to the service. :param callback: callback which is called when request has been acknowledged by the service. - This callback should have one parameter, which will contain the requested twin when called. + This callback should have two parameters. On success, this callback is called with the + requested twin and error=None. On failure, this callback is called with None for the requested + twin and error set to the cause of the failure. + + The following exceptions are not "raised", but rather returned via the "error" parameter + when invoking "callback": + + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionFailedError` + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionDroppedError` + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.UnauthorizedError` + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` """ - def on_complete(op): - if op.error: - callback(error=op.error, twin=None) + def on_complete(op, error): + if error: + callback(error=error, twin=None) else: callback(twin=op.twin) @@ -225,13 +312,18 @@ class IoTHubPipeline(object): :param patch: the reported properties patch to send :param callback: callback which is called when request has been acknowledged by the service. + + The following exceptions are not "raised", but rather returned via the "error" parameter + when invoking "callback": + + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionFailedError` + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ConnectionDroppedError` + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.UnauthorizedError` + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` """ - def on_complete(op): - if op.error: - callback(error=op.error) - else: - callback() + def on_complete(op, error): + callback(error=error) self._pipeline.run_op( pipeline_ops_iothub.PatchTwinReportedPropertiesOperation( @@ -253,11 +345,8 @@ class IoTHubPipeline(object): raise ValueError("Invalid feature_name") self.feature_enabled[feature_name] = True - def on_complete(op): - if op.error: - callback(error=op.error) - else: - callback() + def on_complete(op, error): + callback(error=error) self._pipeline.run_op( pipeline_ops_base.EnableFeatureOperation( @@ -279,14 +368,18 @@ class IoTHubPipeline(object): raise ValueError("Invalid feature_name") self.feature_enabled[feature_name] = False - def on_complete(op): - if op.error: - callback(error=op.error) - else: - callback() + def on_complete(op, error): + callback(error=error) self._pipeline.run_op( pipeline_ops_base.DisableFeatureOperation( feature_name=feature_name, callback=on_complete ) ) + + @property + def connected(self): + """ + Read-only property to indicate if the transport is connected or not. + """ + return self._pipeline.connected diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_ops_iothub.py b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_ops_iothub.py index 13b22d392..117a8c731 100644 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_ops_iothub.py +++ b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_ops_iothub.py @@ -18,15 +18,15 @@ class SetX509AuthProviderOperation(PipelineOperation): very IoTHub-specific """ - def __init__(self, auth_provider, callback=None): + def __init__(self, auth_provider, callback): """ Initializer for SetAuthProviderOperation objects. :param object auth_provider: The X509 authorization provider object to use to retrieve connection parameters which can be used to connect to the service. :param Function callback: The function that gets called when this operation is complete or has failed. - The callback function must accept A PipelineOperation object which indicates the specific operation which - has completed or failed. + The callback function must accept A PipelineOperation object which indicates the specific operation which + has completed or failed. """ super(SetX509AuthProviderOperation, self).__init__(callback=callback) self.auth_provider = auth_provider @@ -42,7 +42,7 @@ class SetAuthProviderOperation(PipelineOperation): very IoTHub-specific """ - def __init__(self, auth_provider, callback=None): + def __init__(self, auth_provider, callback): """ Initializer for SetAuthProviderOperation objects. @@ -69,12 +69,12 @@ class SetIoTHubConnectionArgsOperation(PipelineOperation): self, device_id, hostname, + callback, module_id=None, gateway_hostname=None, - ca_cert=None, + server_verification_cert=None, client_cert=None, sas_token=None, - callback=None, ): """ Initializer for SetIoTHubConnectionArgsOperation objects. @@ -85,8 +85,8 @@ class SetIoTHubConnectionArgsOperation(PipelineOperation): for the module we are connecting. :param str gateway_hostname: (optional) If we are going through a gateway host, this is the hostname for the gateway - :param str ca_cert: (Optional) The CA certificate to use if the server that we're going to - connect to uses server-side TLS + :param str server_verification_cert: (Optional) The server verification certificate to use + if the server that we're going to connect to uses server-side TLS :param X509 client_cert: (Optional) The x509 object containing a client certificate and key used to connect to the service :param str sas_token: The token string which will be used to authenticate with the service @@ -99,7 +99,7 @@ class SetIoTHubConnectionArgsOperation(PipelineOperation): self.module_id = module_id self.hostname = hostname self.gateway_hostname = gateway_hostname - self.ca_cert = ca_cert + self.server_verification_cert = server_verification_cert self.client_cert = client_cert self.sas_token = sas_token @@ -111,7 +111,7 @@ class SendD2CMessageOperation(PipelineOperation): This operation is in the group of IoTHub operations because it is very specific to the IoTHub client """ - def __init__(self, message, callback=None): + def __init__(self, message, callback): """ Initializer for SendD2CMessageOperation objects. @@ -131,7 +131,7 @@ class SendOutputEventOperation(PipelineOperation): This operation is in the group of IoTHub operations because it is very specific to the IoTHub client """ - def __init__(self, message, callback=None): + def __init__(self, message, callback): """ Initializer for SendOutputEventOperation objects. @@ -152,7 +152,7 @@ class SendMethodResponseOperation(PipelineOperation): This operation is in the group of IoTHub operations because it is very specific to the IoTHub client. """ - def __init__(self, method_response, callback=None): + def __init__(self, method_response, callback): """ Initializer for SendMethodResponseOperation objects. @@ -176,7 +176,7 @@ class GetTwinOperation(PipelineOperation): :type twin: Twin """ - def __init__(self, callback=None): + def __init__(self, callback): """ Initializer for GetTwinOperation objects. """ @@ -190,7 +190,7 @@ class PatchTwinReportedPropertiesOperation(PipelineOperation): IoT Hub or Azure IoT Edge Hub service. """ - def __init__(self, patch, callback=None): + def __init__(self, patch, callback): """ Initializer for PatchTwinReportedPropertiesOperation object diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_ops_iothub_http.py b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_ops_iothub_http.py new file mode 100644 index 000000000..d246e24f7 --- /dev/null +++ b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_ops_iothub_http.py @@ -0,0 +1,79 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from azure.iot.device.common.pipeline import PipelineOperation + + +class MethodInvokeOperation(PipelineOperation): + """ + A PipleineOperation object which contains arguments used to send a method invoke to an IoTHub or EdgeHub server. + + This operation is in the group of EdgeHub operations because it is very specific to the EdgeHub client. + """ + + def __init__(self, target_device_id, target_module_id, method_params, callback): + """ + Initializer for MethodInvokeOperation objects. + + :param str target_device_id: The device id of the target device/module + :param str target_module_id: The module id of the target module + :param method_params: The parameters used to invoke the method, as defined by the IoT Hub specification. + :param callback: The function that gets called when this operation is complete or has failed. + The callback function must accept a PipelineOperation object which indicates the specific operation has which + has completed or failed. + :type callback: Function/callable + """ + super(MethodInvokeOperation, self).__init__(callback=callback) + self.target_device_id = target_device_id + self.target_module_id = target_module_id + self.method_params = method_params + self.method_response = None + + +class GetStorageInfoOperation(PipelineOperation): + """ + A PipleineOperation object which contains arguments used to get the storage information from IoT Hub. + """ + + def __init__(self, blob_name, callback): + """ + Initializer for GetStorageInfo objects. + + :param str blob_name: The name of the blob that will be created in Azure Storage + :param callback: The function that gets called when this operation is complete or has failed. + The callback function must accept a PipelineOperation object which indicates the specific operation has which + has completed or failed. + :type callback: Function/callable + + :ivar dict storage_info: Upon completion, this contains the storage information which was retrieved from the service. + """ + super(GetStorageInfoOperation, self).__init__(callback=callback) + self.blob_name = blob_name + self.storage_info = None + + +class NotifyBlobUploadStatusOperation(PipelineOperation): + """ + A PipleineOperation object which contains arguments used to get the storage information from IoT Hub. + """ + + def __init__(self, correlation_id, is_success, status_code, status_description, callback): + """ + Initializer for GetStorageInfo objects. + + :param str correlation_id: Provided by IoT Hub on get_storage_info_for_blob request. + :param bool is_success: A boolean that indicates whether the file was uploaded successfully. + :param int request_status_code: A numeric status code that is the status for the upload of the fiel to storage. + :param str status_description: A description that corresponds to the status_code. + :param callback: The function that gets called when this operation is complete or has failed. + The callback function must accept a PipelineOperation object which indicates the specific operation has which + has completed or failed. + :type callback: Function/callable + """ + super(NotifyBlobUploadStatusOperation, self).__init__(callback=callback) + self.correlation_id = correlation_id + self.is_success = is_success + self.request_status_code = status_code + self.status_description = status_description diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub.py b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub.py index 53fa1bf92..2b7d67d97 100644 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub.py +++ b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub.py @@ -6,13 +6,10 @@ import json import logging -from azure.iot.device.common.pipeline import ( - pipeline_ops_base, - PipelineStage, - operation_flow, - pipeline_thread, -) -from azure.iot.device.common import unhandled_exceptions +from azure.iot.device.common.pipeline import pipeline_ops_base, PipelineStage, pipeline_thread +from azure.iot.device import exceptions +from azure.iot.device.common import handle_exceptions +from azure.iot.device.common.callable_weak_method import CallableWeakMethod from . import pipeline_ops_iothub from . import constant @@ -32,138 +29,146 @@ class UseAuthProviderStage(PipelineStage): """ @pipeline_thread.runs_on_pipeline_thread - def _execute_op(self, op): + def _run_op(self, op): if isinstance(op, pipeline_ops_iothub.SetAuthProviderOperation): self.auth_provider = op.auth_provider - self.auth_provider.on_sas_token_updated_handler = self.on_sas_token_updated - operation_flow.delegate_to_different_op( - stage=self, - original_op=op, - new_op=pipeline_ops_iothub.SetIoTHubConnectionArgsOperation( - device_id=self.auth_provider.device_id, - module_id=getattr(self.auth_provider, "module_id", None), - hostname=self.auth_provider.hostname, - gateway_hostname=getattr(self.auth_provider, "gateway_hostname", None), - ca_cert=getattr(self.auth_provider, "ca_cert", None), - sas_token=self.auth_provider.get_current_sas_token(), - ), + # Here we append rather than just add it to the handler value because otherwise it + # would overwrite the handler from another pipeline that might be using the same auth provider. + self.auth_provider.on_sas_token_updated_handler_list.append( + CallableWeakMethod(self, "_on_sas_token_updated") ) + worker_op = op.spawn_worker_op( + worker_op_type=pipeline_ops_iothub.SetIoTHubConnectionArgsOperation, + device_id=self.auth_provider.device_id, + module_id=self.auth_provider.module_id, + hostname=self.auth_provider.hostname, + gateway_hostname=getattr(self.auth_provider, "gateway_hostname", None), + server_verification_cert=getattr( + self.auth_provider, "server_verification_cert", None + ), + sas_token=self.auth_provider.get_current_sas_token(), + ) + self.send_op_down(worker_op) + elif isinstance(op, pipeline_ops_iothub.SetX509AuthProviderOperation): self.auth_provider = op.auth_provider - operation_flow.delegate_to_different_op( - stage=self, - original_op=op, - new_op=pipeline_ops_iothub.SetIoTHubConnectionArgsOperation( - device_id=self.auth_provider.device_id, - module_id=getattr(self.auth_provider, "module_id", None), - hostname=self.auth_provider.hostname, - gateway_hostname=getattr(self.auth_provider, "gateway_hostname", None), - ca_cert=getattr(self.auth_provider, "ca_cert", None), - client_cert=self.auth_provider.get_x509_certificate(), + worker_op = op.spawn_worker_op( + worker_op_type=pipeline_ops_iothub.SetIoTHubConnectionArgsOperation, + device_id=self.auth_provider.device_id, + module_id=self.auth_provider.module_id, + hostname=self.auth_provider.hostname, + gateway_hostname=getattr(self.auth_provider, "gateway_hostname", None), + server_verification_cert=getattr( + self.auth_provider, "server_verification_cert", None ), + client_cert=self.auth_provider.get_x509_certificate(), ) + self.send_op_down(worker_op) else: - operation_flow.pass_op_to_next_stage(self, op) + super(UseAuthProviderStage, self)._run_op(op) @pipeline_thread.invoke_on_pipeline_thread_nowait - def on_sas_token_updated(self): + def _on_sas_token_updated(self): logger.info( "{}: New sas token received. Passing down UpdateSasTokenOperation.".format(self.name) ) @pipeline_thread.runs_on_pipeline_thread - def on_token_update_complete(op): - if op.error: + def on_token_update_complete(op, error): + if error: logger.error( "{}({}): token update operation failed. Error={}".format( - self.name, op.name, op.error + self.name, op.name, error ) ) - unhandled_exceptions.exception_caught_in_background_thread(op.error) + handle_exceptions.handle_background_exception(error) else: logger.debug( "{}({}): token update operation is complete".format(self.name, op.name) ) - operation_flow.pass_op_to_next_stage( - stage=self, - op=pipeline_ops_base.UpdateSasTokenOperation( + self.send_op_down( + pipeline_ops_base.UpdateSasTokenOperation( sas_token=self.auth_provider.get_current_sas_token(), callback=on_token_update_complete, - ), + ) ) -class HandleTwinOperationsStage(PipelineStage): +class TwinRequestResponseStage(PipelineStage): """ PipelineStage which handles twin operations. In particular, it converts twin GET and PATCH - operations into SendIotRequestAndWaitForResponseOperation operations. This is done at the IoTHub level because + operations into RequestAndResponseOperation operations. This is done at the IoTHub level because there is nothing protocol-specific about this code. The protocol-specific implementation - for twin requests and responses is handled inside IoTHubMQTTConverterStage, when it converts - the SendIotRequestOperation to a protocol-specific send operation and when it converts the - protocol-specific receive event into an IotResponseEvent event. + for twin requests and responses is handled inside IoTHubMQTTTranslationStage, when it converts + the RequestOperation to a protocol-specific send operation and when it converts the + protocol-specific receive event into an ResponseEvent event. """ @pipeline_thread.runs_on_pipeline_thread - def _execute_op(self, op): - def map_twin_error(original_op, twin_op): - if twin_op.error: - original_op.error = twin_op.error + def _run_op(self, op): + def map_twin_error(error, twin_op): + if error: + return error elif twin_op.status_code >= 300: # TODO map error codes to correct exceptions logger.error("Error {} received from twin operation".format(twin_op.status_code)) logger.error("response body: {}".format(twin_op.response_body)) - original_op.error = Exception( + return exceptions.ServiceError( "twin operation returned status {}".format(twin_op.status_code) ) if isinstance(op, pipeline_ops_iothub.GetTwinOperation): - def on_twin_response(twin_op): - logger.debug("{}({}): Got response for GetTwinOperation".format(self.name, op.name)) - map_twin_error(original_op=op, twin_op=twin_op) - if not twin_op.error: - op.twin = json.loads(twin_op.response_body.decode("utf-8")) - operation_flow.complete_op(self, op) + # Alias to avoid overload within the callback below + # CT-TODO: remove the need for this with better callback semantics + op_waiting_for_response = op - operation_flow.pass_op_to_next_stage( - self, - pipeline_ops_base.SendIotRequestAndWaitForResponseOperation( + def on_twin_response(op, error): + logger.debug("{}({}): Got response for GetTwinOperation".format(self.name, op.name)) + error = map_twin_error(error=error, twin_op=op) + if not error: + op_waiting_for_response.twin = json.loads(op.response_body.decode("utf-8")) + op_waiting_for_response.complete(error=error) + + self.send_op_down( + pipeline_ops_base.RequestAndResponseOperation( request_type=constant.TWIN, method="GET", resource_location="/", request_body=" ", callback=on_twin_response, - ), + ) ) elif isinstance(op, pipeline_ops_iothub.PatchTwinReportedPropertiesOperation): - def on_twin_response(twin_op): + # Alias to avoid overload within the callback below + # CT-TODO: remove the need for this with better callback semantics + op_waiting_for_response = op + + def on_twin_response(op, error): logger.debug( "{}({}): Got response for PatchTwinReportedPropertiesOperation operation".format( self.name, op.name ) ) - map_twin_error(original_op=op, twin_op=twin_op) - operation_flow.complete_op(self, op) + error = map_twin_error(error=error, twin_op=op) + op_waiting_for_response.complete(error=error) logger.debug( "{}({}): Sending reported properties patch: {}".format(self.name, op.name, op.patch) ) - operation_flow.pass_op_to_next_stage( - self, - ( - pipeline_ops_base.SendIotRequestAndWaitForResponseOperation( - request_type=constant.TWIN, - method="PATCH", - resource_location="/properties/reported/", - request_body=json.dumps(op.patch), - callback=on_twin_response, - ) - ), + self.send_op_down( + pipeline_ops_base.RequestAndResponseOperation( + request_type=constant.TWIN, + method="PATCH", + resource_location="/properties/reported/", + request_body=json.dumps(op.patch), + callback=on_twin_response, + ) ) else: - operation_flow.pass_op_to_next_stage(self, op) + super(TwinRequestResponseStage, self)._run_op(op) diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_http.py b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_http.py new file mode 100644 index 000000000..7b156488b --- /dev/null +++ b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_http.py @@ -0,0 +1,225 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import logging +import json +import six.moves.urllib as urllib +from azure.iot.device.common.pipeline import ( + pipeline_events_base, + pipeline_ops_base, + pipeline_ops_http, + PipelineStage, + pipeline_thread, +) +from . import pipeline_ops_iothub, pipeline_ops_iothub_http, http_path_iothub, http_map_error +from azure.iot.device import exceptions +from azure.iot.device import constant as pkg_constant +from azure.iot.device.product_info import ProductInfo + + +logger = logging.getLogger(__name__) + + +@pipeline_thread.runs_on_pipeline_thread +def map_http_error(error, http_op): + if error: + return error + elif http_op.status_code >= 300: + translated_error = http_map_error.translate_error(http_op.status_code, http_op.reason) + return exceptions.ServiceError( + "HTTP operation returned: {} {}".format(http_op.status_code, translated_error) + ) + + +class IoTHubHTTPTranslationStage(PipelineStage): + """ + PipelineStage which converts other Iot and EdgeHub operations into HTTP operations. This stage also + converts http pipeline events into Iot and EdgeHub pipeline events. + """ + + def __init__(self): + super(IoTHubHTTPTranslationStage, self).__init__() + self.device_id = None + self.module_id = None + self.hostname = None + + @pipeline_thread.runs_on_pipeline_thread + def _run_op(self, op): + if isinstance(op, pipeline_ops_iothub.SetIoTHubConnectionArgsOperation): + self.device_id = op.device_id + self.module_id = op.module_id + + if op.gateway_hostname: + logger.debug( + "Gateway Hostname Present. Setting Hostname to: {}".format(op.gateway_hostname) + ) + self.hostname = op.gateway_hostname + else: + logger.debug( + "Gateway Hostname not present. Setting Hostname to: {}".format( + op.gateway_hostname + ) + ) + self.hostname = op.hostname + worker_op = op.spawn_worker_op( + worker_op_type=pipeline_ops_http.SetHTTPConnectionArgsOperation, + hostname=self.hostname, + server_verification_cert=op.server_verification_cert, + client_cert=op.client_cert, + sas_token=op.sas_token, + ) + self.send_op_down(worker_op) + + elif isinstance(op, pipeline_ops_iothub_http.MethodInvokeOperation): + logger.debug( + "{}({}): Translating Method Invoke Operation for HTTP.".format(self.name, op.name) + ) + query_params = "api-version={apiVersion}".format( + apiVersion=pkg_constant.IOTHUB_API_VERSION + ) + # if the target is a module. + + body = json.dumps(op.method_params) + path = http_path_iothub.get_method_invoke_path(op.target_device_id, op.target_module_id) + # Note we do not add the sas Authorization header here. Instead we add it later on in the stage above + # the transport layer, since that stage stores the updated SAS and also X509 certs if that is what is + # being used. + x_ms_edge_string = "{deviceId}/{moduleId}".format( + deviceId=self.device_id, moduleId=self.module_id + ) # these are the identifiers of the current module + user_agent = urllib.parse.quote_plus( + ProductInfo.get_iothub_user_agent() + + str(self.pipeline_root.pipeline_configuration.product_info) + ) + headers = { + "Host": self.hostname, + "Content-Type": "application/json", + "Content-Length": len(str(body)), + "x-ms-edge-moduleId": x_ms_edge_string, + "User-Agent": user_agent, + } + op_waiting_for_response = op + + def on_request_response(op, error): + logger.debug( + "{}({}): Got response for MethodInvokeOperation".format(self.name, op.name) + ) + error = map_http_error(error=error, http_op=op) + if not error: + op_waiting_for_response.method_response = json.loads( + op.response_body.decode("utf-8") + ) + op_waiting_for_response.complete(error=error) + + self.send_op_down( + pipeline_ops_http.HTTPRequestAndResponseOperation( + method="POST", + path=path, + headers=headers, + body=body, + query_params=query_params, + callback=on_request_response, + ) + ) + + elif isinstance(op, pipeline_ops_iothub_http.GetStorageInfoOperation): + logger.debug( + "{}({}): Translating Get Storage Info Operation to HTTP.".format(self.name, op.name) + ) + query_params = "api-version={apiVersion}".format( + apiVersion=pkg_constant.IOTHUB_API_VERSION + ) + path = http_path_iothub.get_storage_info_for_blob_path(self.device_id) + body = json.dumps({"blobName": op.blob_name}) + user_agent = urllib.parse.quote_plus( + ProductInfo.get_iothub_user_agent() + + str(self.pipeline_root.pipeline_configuration.product_info) + ) + headers = { + "Host": self.hostname, + "Accept": "application/json", + "Content-Type": "application/json", + "Content-Length": len(str(body)), + "User-Agent": user_agent, + } + + op_waiting_for_response = op + + def on_request_response(op, error): + logger.debug( + "{}({}): Got response for GetStorageInfoOperation".format(self.name, op.name) + ) + error = map_http_error(error=error, http_op=op) + if not error: + op_waiting_for_response.storage_info = json.loads( + op.response_body.decode("utf-8") + ) + op_waiting_for_response.complete(error=error) + + self.send_op_down( + pipeline_ops_http.HTTPRequestAndResponseOperation( + method="POST", + path=path, + headers=headers, + body=body, + query_params=query_params, + callback=on_request_response, + ) + ) + + elif isinstance(op, pipeline_ops_iothub_http.NotifyBlobUploadStatusOperation): + logger.debug( + "{}({}): Translating Get Storage Info Operation to HTTP.".format(self.name, op.name) + ) + query_params = "api-version={apiVersion}".format( + apiVersion=pkg_constant.IOTHUB_API_VERSION + ) + path = http_path_iothub.get_notify_blob_upload_status_path(self.device_id) + body = json.dumps( + { + "correlationId": op.correlation_id, + "isSuccess": op.is_success, + "statusCode": op.request_status_code, + "statusDescription": op.status_description, + } + ) + user_agent = urllib.parse.quote_plus( + ProductInfo.get_iothub_user_agent() + + str(self.pipeline_root.pipeline_configuration.product_info) + ) + + # Note we do not add the sas Authorization header here. Instead we add it later on in the stage above + # the transport layer, since that stage stores the updated SAS and also X509 certs if that is what is + # being used. + headers = { + "Host": self.hostname, + "Content-Type": "application/json; charset=utf-8", + "Content-Length": len(str(body)), + "User-Agent": user_agent, + } + op_waiting_for_response = op + + def on_request_response(op, error): + logger.debug( + "{}({}): Got response for GetStorageInfoOperation".format(self.name, op.name) + ) + error = map_http_error(error=error, http_op=op) + op_waiting_for_response.complete(error=error) + + self.send_op_down( + pipeline_ops_http.HTTPRequestAndResponseOperation( + method="POST", + path=path, + headers=headers, + body=body, + query_params=query_params, + callback=on_request_response, + ) + ) + + else: + # All other operations get passed down + self.send_op_down(op) diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_mqtt.py b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_mqtt.py index 77b929163..228540032 100644 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_mqtt.py +++ b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_mqtt.py @@ -13,29 +13,32 @@ from azure.iot.device.common.pipeline import ( pipeline_ops_mqtt, pipeline_events_mqtt, PipelineStage, - operation_flow, pipeline_thread, ) from azure.iot.device.iothub.models import Message, MethodRequest from . import pipeline_ops_iothub, pipeline_events_iothub, mqtt_topic_iothub from . import constant as pipeline_constant +from . import exceptions as pipeline_exceptions from azure.iot.device import constant as pkg_constant +from azure.iot.device.product_info import ProductInfo logger = logging.getLogger(__name__) -class IoTHubMQTTConverterStage(PipelineStage): +class IoTHubMQTTTranslationStage(PipelineStage): """ PipelineStage which converts other Iot and IoTHub operations into MQTT operations. This stage also converts mqtt pipeline events into Iot and IoTHub pipeline events. """ def __init__(self): - super(IoTHubMQTTConverterStage, self).__init__() + super(IoTHubMQTTTranslationStage, self).__init__() self.feature_to_topic = {} + self.device_id = None + self.module_id = None @pipeline_thread.runs_on_pipeline_thread - def _execute_op(self, op): + def _run_op(self, op): if isinstance(op, pipeline_ops_iothub.SetIoTHubConnectionArgsOperation): self.device_id = op.device_id @@ -51,14 +54,21 @@ class IoTHubMQTTConverterStage(PipelineStage): else: client_id = op.device_id + # For MQTT, the entire user agent string should be appended to the username field in the connect packet + # For example, the username may look like this without custom parameters: + # yosephsandboxhub.azure-devices.net/alpha/?api-version=2018-06-30&DeviceClientType=py-azure-iot-device%2F2.0.0-preview.12 + # The customer user agent string would simply be appended to the end of this username, in URL Encoded format. query_param_seq = [ ("api-version", pkg_constant.IOTHUB_API_VERSION), - ("DeviceClientType", pkg_constant.USER_AGENT), + ("DeviceClientType", ProductInfo.get_iothub_user_agent()), ] - username = "{hostname}/{client_id}/?{query_params}".format( + username = "{hostname}/{client_id}/?{query_params}{optional_product_info}".format( hostname=op.hostname, client_id=client_id, query_params=urllib.parse.urlencode(query_param_seq), + optional_product_info=urllib.parse.quote( + str(self.pipeline_root.pipeline_configuration.product_info) + ), ) if op.gateway_hostname: @@ -67,91 +77,64 @@ class IoTHubMQTTConverterStage(PipelineStage): hostname = op.hostname # TODO: test to make sure client_cert and sas_token travel down correctly - operation_flow.delegate_to_different_op( - stage=self, - original_op=op, - new_op=pipeline_ops_mqtt.SetMQTTConnectionArgsOperation( - client_id=client_id, - hostname=hostname, - username=username, - ca_cert=op.ca_cert, - client_cert=op.client_cert, - sas_token=op.sas_token, - ), + worker_op = op.spawn_worker_op( + worker_op_type=pipeline_ops_mqtt.SetMQTTConnectionArgsOperation, + client_id=client_id, + hostname=hostname, + username=username, + server_verification_cert=op.server_verification_cert, + client_cert=op.client_cert, + sas_token=op.sas_token, ) + self.send_op_down(worker_op) elif ( isinstance(op, pipeline_ops_base.UpdateSasTokenOperation) and self.pipeline_root.connected ): logger.debug( - "{}({}): Connected. Passing op down and reconnecting after token is updated.".format( + "{}({}): Connected. Passing op down and reauthorizing after token is updated.".format( self.name, op.name ) ) - # make a callback that can call the user's callback after the reconnect is complete - def on_reconnect_complete(reconnect_op): - if reconnect_op.error: - op.error = reconnect_op.error - logger.error( - "{}({}) reconnection failed. returning error {}".format( - self.name, op.name, op.error - ) - ) - operation_flow.complete_op(stage=self, op=op) - else: - logger.debug( - "{}({}) reconnection succeeded. returning success.".format( - self.name, op.name - ) - ) - operation_flow.complete_op(stage=self, op=op) - - # save the old user callback so we can call it later. - old_callback = op.callback - # make a callback that either fails the UpdateSasTokenOperation (if the lower level failed it), - # or issues a ReconnectOperation (if the lower level returned success for the UpdateSasTokenOperation) - def on_token_update_complete(op): - op.callback = old_callback - if op.error: + # or issues a ReauthorizeConnectionOperation (if the lower level returned success for the UpdateSasTokenOperation) + def on_token_update_complete(op, error): + if error: logger.error( "{}({}) token update failed. returning failure {}".format( - self.name, op.name, op.error + self.name, op.name, error ) ) - operation_flow.complete_op(stage=self, op=op) else: logger.debug( - "{}({}) token update succeeded. reconnecting".format(self.name, op.name) + "{}({}) token update succeeded. reauthorizing".format(self.name, op.name) ) - operation_flow.pass_op_to_next_stage( - stage=self, - op=pipeline_ops_base.ReconnectOperation(callback=on_reconnect_complete), + # Stop completion of Token Update op, and only continue upon completion of ReauthorizeConnectionOperation + op.halt_completion() + worker_op = op.spawn_worker_op( + worker_op_type=pipeline_ops_base.ReauthorizeConnectionOperation ) - logger.debug( - "{}({}): passing to next stage with updated callback.".format( - self.name, op.name - ) - ) + self.send_op_down(worker_op) # now, pass the UpdateSasTokenOperation down with our new callback. - op.callback = on_token_update_complete - operation_flow.pass_op_to_next_stage(stage=self, op=op) + op.add_callback(on_token_update_complete) + self.send_op_down(op) elif isinstance(op, pipeline_ops_iothub.SendD2CMessageOperation) or isinstance( op, pipeline_ops_iothub.SendOutputEventOperation ): # Convert SendTelementry and SendOutputEventOperation operations into MQTT Publish operations topic = mqtt_topic_iothub.encode_properties(op.message, self.telemetry_topic) - operation_flow.delegate_to_different_op( - stage=self, - original_op=op, - new_op=pipeline_ops_mqtt.MQTTPublishOperation(topic=topic, payload=op.message.data), + worker_op = op.spawn_worker_op( + worker_op_type=pipeline_ops_mqtt.MQTTPublishOperation, + topic=topic, + payload=op.message.data, ) + self.send_op_down(worker_op) elif isinstance(op, pipeline_ops_iothub.SendMethodResponseOperation): # Sending a Method Response gets translated into an MQTT Publish operation @@ -159,52 +142,48 @@ class IoTHubMQTTConverterStage(PipelineStage): op.method_response.request_id, str(op.method_response.status) ) payload = json.dumps(op.method_response.payload) - operation_flow.delegate_to_different_op( - stage=self, - original_op=op, - new_op=pipeline_ops_mqtt.MQTTPublishOperation(topic=topic, payload=payload), + worker_op = op.spawn_worker_op( + worker_op_type=pipeline_ops_mqtt.MQTTPublishOperation, topic=topic, payload=payload ) + self.send_op_down(worker_op) elif isinstance(op, pipeline_ops_base.EnableFeatureOperation): # Enabling a feature gets translated into an MQTT subscribe operation topic = self.feature_to_topic[op.feature_name] - operation_flow.delegate_to_different_op( - stage=self, - original_op=op, - new_op=pipeline_ops_mqtt.MQTTSubscribeOperation(topic=topic), + worker_op = op.spawn_worker_op( + worker_op_type=pipeline_ops_mqtt.MQTTSubscribeOperation, topic=topic ) + self.send_op_down(worker_op) elif isinstance(op, pipeline_ops_base.DisableFeatureOperation): # Disabling a feature gets turned into an MQTT unsubscribe operation topic = self.feature_to_topic[op.feature_name] - operation_flow.delegate_to_different_op( - stage=self, - original_op=op, - new_op=pipeline_ops_mqtt.MQTTUnsubscribeOperation(topic=topic), + worker_op = op.spawn_worker_op( + worker_op_type=pipeline_ops_mqtt.MQTTUnsubscribeOperation, topic=topic ) + self.send_op_down(worker_op) - elif isinstance(op, pipeline_ops_base.SendIotRequestOperation): + elif isinstance(op, pipeline_ops_base.RequestOperation): if op.request_type == pipeline_constant.TWIN: topic = mqtt_topic_iothub.get_twin_topic_for_publish( method=op.method, resource_location=op.resource_location, request_id=op.request_id, ) - operation_flow.delegate_to_different_op( - stage=self, - original_op=op, - new_op=pipeline_ops_mqtt.MQTTPublishOperation( - topic=topic, payload=op.request_body - ), + worker_op = op.spawn_worker_op( + worker_op_type=pipeline_ops_mqtt.MQTTPublishOperation, + topic=topic, + payload=op.request_body, ) + self.send_op_down(worker_op) else: - raise NotImplementedError( - "SendIotRequestOperation request_type {} not supported".format(op.request_type) + raise pipeline_exceptions.OperationError( + "RequestOperation request_type {} not supported".format(op.request_type) ) else: # All other operations get passed down - operation_flow.pass_op_to_next_stage(self, op) + super(IoTHubMQTTTranslationStage, self)._run_op(op) @pipeline_thread.runs_on_pipeline_thread def _set_topic_names(self, device_id, module_id): @@ -240,17 +219,13 @@ class IoTHubMQTTConverterStage(PipelineStage): if mqtt_topic_iothub.is_c2d_topic(topic, self.device_id): message = Message(event.payload) mqtt_topic_iothub.extract_properties_from_topic(topic, message) - operation_flow.pass_event_to_previous_stage( - self, pipeline_events_iothub.C2DMessageEvent(message) - ) + self.send_event_up(pipeline_events_iothub.C2DMessageEvent(message)) elif mqtt_topic_iothub.is_input_topic(topic, self.device_id, self.module_id): message = Message(event.payload) mqtt_topic_iothub.extract_properties_from_topic(topic, message) input_name = mqtt_topic_iothub.get_input_name_from_topic(topic) - operation_flow.pass_event_to_previous_stage( - self, pipeline_events_iothub.InputMessageEvent(input_name, message) - ) + self.send_event_up(pipeline_events_iothub.InputMessageEvent(input_name, message)) elif mqtt_topic_iothub.is_method_topic(topic): request_id = mqtt_topic_iothub.get_method_request_id_from_topic(topic) @@ -260,32 +235,28 @@ class IoTHubMQTTConverterStage(PipelineStage): name=method_name, payload=json.loads(event.payload.decode("utf-8")), ) - operation_flow.pass_event_to_previous_stage( - self, pipeline_events_iothub.MethodRequestEvent(method_received) - ) + self.send_event_up(pipeline_events_iothub.MethodRequestEvent(method_received)) elif mqtt_topic_iothub.is_twin_response_topic(topic): request_id = mqtt_topic_iothub.get_twin_request_id_from_topic(topic) status_code = int(mqtt_topic_iothub.get_twin_status_code_from_topic(topic)) - operation_flow.pass_event_to_previous_stage( - self, - pipeline_events_base.IotResponseEvent( + self.send_event_up( + pipeline_events_base.ResponseEvent( request_id=request_id, status_code=status_code, response_body=event.payload - ), + ) ) elif mqtt_topic_iothub.is_twin_desired_property_patch_topic(topic): - operation_flow.pass_event_to_previous_stage( - self, + self.send_event_up( pipeline_events_iothub.TwinDesiredPropertiesPatchEvent( patch=json.loads(event.payload.decode("utf-8")) - ), + ) ) else: - logger.debug("Uunknown topic: {} passing up to next handler".format(topic)) - operation_flow.pass_event_to_previous_stage(self, event) + logger.debug("Unknown topic: {} passing up to next handler".format(topic)) + self.send_event_up(event) else: # all other messages get passed up - operation_flow.pass_event_to_previous_stage(self, event) + super(IoTHubMQTTTranslationStage, self)._handle_pipeline_event(event) diff --git a/azure-iot-device/azure/iot/device/iothub/sync_clients.py b/azure-iot-device/azure/iot/device/iothub/sync_clients.py index e76a1bd01..0fdbfa7d5 100644 --- a/azure-iot-device/azure/iot/device/iothub/sync_clients.py +++ b/azure-iot-device/azure/iot/device/iothub/sync_clients.py @@ -15,13 +15,41 @@ from .abstract_clients import ( ) from .models import Message from .inbox_manager import InboxManager -from .sync_inbox import SyncClientInbox -from .pipeline import constant +from .sync_inbox import SyncClientInbox, InboxEmpty +from .pipeline import constant as pipeline_constant +from .pipeline import exceptions as pipeline_exceptions +from azure.iot.device import exceptions from azure.iot.device.common.evented_callback import EventedCallback +from azure.iot.device.common.callable_weak_method import CallableWeakMethod +from azure.iot.device import constant as device_constant + logger = logging.getLogger(__name__) +def handle_result(callback): + try: + return callback.wait_for_completion() + except pipeline_exceptions.ConnectionDroppedError as e: + raise exceptions.ConnectionDroppedError(message="Lost connection to IoTHub", cause=e) + except pipeline_exceptions.ConnectionFailedError as e: + raise exceptions.ConnectionFailedError(message="Could not connect to IoTHub", cause=e) + except pipeline_exceptions.UnauthorizedError as e: + raise exceptions.CredentialError(message="Credentials invalid, could not connect", cause=e) + except pipeline_exceptions.ProtocolClientError as e: + raise exceptions.ClientError(message="Error in the IoTHub client", cause=e) + except pipeline_exceptions.TlsExchangeAuthError as e: + raise exceptions.ClientError( + message="Error in the IoTHub client due to TLS exchanges.", cause=e + ) + except pipeline_exceptions.ProtocolProxyError as e: + raise exceptions.ClientError( + message="Error in the IoTHub client raised due to proxy connections.", cause=e + ) + except Exception as e: + raise exceptions.ClientError(message="Unexpected failure", cause=e) + + class GenericIoTHubClient(AbstractIoTHubClient): """A superclass representing a generic synchronous client. This class needs to be extended for specific clients. @@ -33,8 +61,10 @@ class GenericIoTHubClient(AbstractIoTHubClient): This initializer should not be called directly. Instead, use one of the 'create_from_' classmethods to instantiate - TODO: How to document kwargs? - Possible values: iothub_pipeline, edge_pipeline + :param iothub_pipeline: The IoTHubPipeline used for the client + :type iothub_pipeline: :class:`azure.iot.device.iothub.pipeline.IoTHubPipeline` + :param http_pipeline: The HTTPPipeline used for the client + :type http_pipeline: :class:`azure.iot.device.iothub.pipeline.HTTPPipeline` """ # Depending on the subclass calling this __init__, there could be different arguments, # and the super() call could call a different class, due to the different MROs @@ -42,10 +72,14 @@ class GenericIoTHubClient(AbstractIoTHubClient): # **kwargs. super(GenericIoTHubClient, self).__init__(**kwargs) self._inbox_manager = InboxManager(inbox_type=SyncClientInbox) - self._iothub_pipeline.on_connected = self._on_connected - self._iothub_pipeline.on_disconnected = self._on_disconnected - self._iothub_pipeline.on_method_request_received = self._inbox_manager.route_method_request - self._iothub_pipeline.on_twin_patch_received = self._inbox_manager.route_twin_patch + self._iothub_pipeline.on_connected = CallableWeakMethod(self, "_on_connected") + self._iothub_pipeline.on_disconnected = CallableWeakMethod(self, "_on_disconnected") + self._iothub_pipeline.on_method_request_received = CallableWeakMethod( + self._inbox_manager, "route_method_request" + ) + self._iothub_pipeline.on_twin_patch_received = CallableWeakMethod( + self._inbox_manager, "route_twin_patch" + ) def _on_connected(self): """Helper handler that is called upon an iothub pipeline connect""" @@ -65,12 +99,21 @@ class GenericIoTHubClient(AbstractIoTHubClient): This is a synchronous call, meaning that this function will not return until the connection to the service has been completely established. + + :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid + and a connection cannot be established. + :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a + connection results in failure. + :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost + during execution. + :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure + during execution. """ logger.info("Connecting to Hub...") callback = EventedCallback() self._iothub_pipeline.connect(callback=callback) - callback.wait_for_completion() + handle_result(callback) logger.info("Successfully connected to Hub") @@ -79,12 +122,15 @@ class GenericIoTHubClient(AbstractIoTHubClient): This is a synchronous call, meaning that this function will not return until the connection to the service has been completely closed. + + :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure + during execution. """ logger.info("Disconnecting from Hub...") callback = EventedCallback() self._iothub_pipeline.disconnect(callback=callback) - callback.wait_for_completion() + handle_result(callback) logger.info("Successfully disconnected from Hub") @@ -98,16 +144,30 @@ class GenericIoTHubClient(AbstractIoTHubClient): function will open the connection before sending the event. :param message: The actual message to send. Anything passed that is not an instance of the - Message class will be converted to Message object. + Message class will be converted to Message object. + :type message: :class:`azure.iot.device.Message` or str + + :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid + and a connection cannot be established. + :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a + connection results in failure. + :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost + during execution. + :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure + during execution. + :raises: ValueError if the message fails size validation. """ if not isinstance(message, Message): message = Message(message) + if message.get_size() > device_constant.TELEMETRY_MESSAGE_SIZE_LIMIT: + raise ValueError("Size of telemetry message can not exceed 256 KB.") + logger.info("Sending message to Hub...") callback = EventedCallback() self._iothub_pipeline.send_message(message, callback=callback) - callback.wait_for_completion() + handle_result(callback) logger.info("Successfully sent message to Hub") @@ -115,24 +175,24 @@ class GenericIoTHubClient(AbstractIoTHubClient): """Receive a method request via the Azure IoT Hub or Azure IoT Edge Hub. :param str method_name: Optionally provide the name of the method to receive requests for. - If this parameter is not given, all methods not already being specifically targeted by - a different request to receive_method will be received. + If this parameter is not given, all methods not already being specifically targeted by + a different request to receive_method will be received. :param bool block: Indicates if the operation should block until a request is received. - Default True. :param int timeout: Optionally provide a number of seconds until blocking times out. - :raises: InboxEmpty if timeout occurs on a blocking operation. - :raises: InboxEmpty if no request is available on a non-blocking operation. - - :returns: MethodRequest object representing the received method request. + :returns: MethodRequest object representing the received method request, or None if + no method request has been received by the end of the blocking period. """ - if not self._iothub_pipeline.feature_enabled[constant.METHODS]: - self._enable_feature(constant.METHODS) + if not self._iothub_pipeline.feature_enabled[pipeline_constant.METHODS]: + self._enable_feature(pipeline_constant.METHODS) method_inbox = self._inbox_manager.get_method_request_inbox(method_name) logger.info("Waiting for method request...") - method_request = method_inbox.get(block=block, timeout=timeout) + try: + method_request = method_inbox.get(block=block, timeout=timeout) + except InboxEmpty: + method_request = None logger.info("Received method request") return method_request @@ -146,13 +206,22 @@ class GenericIoTHubClient(AbstractIoTHubClient): function will open the connection before sending the event. :param method_response: The MethodResponse to send. - :type method_response: MethodResponse + :type method_response: :class:`azure.iot.device.MethodResponse` + + :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid + and a connection cannot be established. + :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a + connection results in failure. + :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost + during execution. + :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure + during execution. """ logger.info("Sending method response to Hub...") callback = EventedCallback() self._iothub_pipeline.send_method_response(method_response, callback=callback) - callback.wait_for_completion() + handle_result(callback) logger.info("Successfully sent method response to Hub") @@ -163,7 +232,7 @@ class GenericIoTHubClient(AbstractIoTHubClient): has been enabled. :param feature_name: The name of the feature to enable. - See azure.iot.device.common.pipeline.constant for possible values + See azure.iot.device.common.pipeline.constant for possible values """ logger.info("Enabling feature:" + feature_name + "...") @@ -180,14 +249,24 @@ class GenericIoTHubClient(AbstractIoTHubClient): This is a synchronous call, meaning that this function will not return until the twin has been retrieved from the service. - :returns: Twin object which was retrieved from the hub + :returns: Complete Twin as a JSON dict + :rtype: dict + + :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid + and a connection cannot be established. + :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a + connection results in failure. + :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost + during execution. + :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure + during execution. """ - if not self._iothub_pipeline.feature_enabled[constant.TWIN]: - self._enable_feature(constant.TWIN) + if not self._iothub_pipeline.feature_enabled[pipeline_constant.TWIN]: + self._enable_feature(pipeline_constant.TWIN) callback = EventedCallback(return_arg_name="twin") self._iothub_pipeline.get_twin(callback=callback) - twin = callback.wait_for_completion() + twin = handle_result(callback) logger.info("Successfully retrieved twin") return twin @@ -202,17 +281,26 @@ class GenericIoTHubClient(AbstractIoTHubClient): If the service returns an error on the patch operation, this function will raise the appropriate error. - :param reported_properties_patch: - :type reported_properties_patch: dict, str, int, float, bool, or None (JSON compatible values) + :param reported_properties_patch: Twin Reported Properties patch as a JSON dict + :type reported_properties_patch: dict + + :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid + and a connection cannot be established. + :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a + connection results in failure. + :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost + during execution. + :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure + during execution. """ - if not self._iothub_pipeline.feature_enabled[constant.TWIN]: - self._enable_feature(constant.TWIN) + if not self._iothub_pipeline.feature_enabled[pipeline_constant.TWIN]: + self._enable_feature(pipeline_constant.TWIN) callback = EventedCallback() self._iothub_pipeline.patch_twin_reported_properties( patch=reported_properties_patch, callback=callback ) - callback.wait_for_completion() + handle_result(callback) logger.info("Successfully patched twin") @@ -231,20 +319,21 @@ class GenericIoTHubClient(AbstractIoTHubClient): an InboxEmpty exception :param bool block: Indicates if the operation should block until a request is received. - Default True. :param int timeout: Optionally provide a number of seconds until blocking times out. - :raises: InboxEmpty if timeout occurs on a blocking operation. - :raises: InboxEmpty if no request is available on a non-blocking operation. - - :returns: desired property patch. This can be dict, str, int, float, bool, or None (JSON compatible values) + :returns: Twin Desired Properties patch as a JSON dict, or None if no patch has been + received by the end of the blocking period + :rtype: dict or None """ - if not self._iothub_pipeline.feature_enabled[constant.TWIN_PATCHES]: - self._enable_feature(constant.TWIN_PATCHES) + if not self._iothub_pipeline.feature_enabled[pipeline_constant.TWIN_PATCHES]: + self._enable_feature(pipeline_constant.TWIN_PATCHES) twin_patch_inbox = self._inbox_manager.get_twin_patch_inbox() logger.info("Waiting for twin patches...") - patch = twin_patch_inbox.get(block=block, timeout=timeout) + try: + patch = twin_patch_inbox.get(block=block, timeout=timeout) + except InboxEmpty: + return None logger.info("twin patch received") return patch @@ -255,39 +344,78 @@ class IoTHubDeviceClient(GenericIoTHubClient, AbstractIoTHubDeviceClient): Intended for usage with Python 2.7 or compatibility scenarios for Python 3.5.3+. """ - def __init__(self, iothub_pipeline): + def __init__(self, iothub_pipeline, http_pipeline): """Initializer for a IoTHubDeviceClient. This initializer should not be called directly. Instead, use one of the 'create_from_' classmethods to instantiate :param iothub_pipeline: The pipeline used to connect to the IoTHub endpoint. - :type iothub_pipeline: IoTHubPipeline + :type iothub_pipeline: :class:`azure.iot.device.iothub.pipeline.IoTHubPipeline` """ - super(IoTHubDeviceClient, self).__init__(iothub_pipeline=iothub_pipeline) - self._iothub_pipeline.on_c2d_message_received = self._inbox_manager.route_c2d_message + super(IoTHubDeviceClient, self).__init__( + iothub_pipeline=iothub_pipeline, http_pipeline=http_pipeline + ) + self._iothub_pipeline.on_c2d_message_received = CallableWeakMethod( + self._inbox_manager, "route_c2d_message" + ) def receive_message(self, block=True, timeout=None): """Receive a message that has been sent from the Azure IoT Hub. :param bool block: Indicates if the operation should block until a message is received. - Default True. :param int timeout: Optionally provide a number of seconds until blocking times out. - :raises: InboxEmpty if timeout occurs on a blocking operation. - :raises: InboxEmpty if no message is available on a non-blocking operation. - - :returns: Message that was sent from the Azure IoT Hub. + :returns: Message that was sent from the Azure IoT Hub, or None if + no method request has been received by the end of the blocking period. + :rtype: :class:`azure.iot.device.Message` or None """ - if not self._iothub_pipeline.feature_enabled[constant.C2D_MSG]: - self._enable_feature(constant.C2D_MSG) + if not self._iothub_pipeline.feature_enabled[pipeline_constant.C2D_MSG]: + self._enable_feature(pipeline_constant.C2D_MSG) c2d_inbox = self._inbox_manager.get_c2d_message_inbox() logger.info("Waiting for message from Hub...") - message = c2d_inbox.get(block=block, timeout=timeout) + try: + message = c2d_inbox.get(block=block, timeout=timeout) + except InboxEmpty: + message = None logger.info("Message received") return message + def get_storage_info_for_blob(self, blob_name): + """Sends a POST request over HTTP to an IoTHub endpoint that will return information for uploading via the Azure Storage Account linked to the IoTHub your device is connected to. + + :param str blob_name: The name in string format of the blob that will be uploaded using the storage API. This name will be used to generate the proper credentials for Storage, and needs to match what will be used with the Azure Storage SDK to perform the blob upload. + + :returns: A JSON-like (dictionary) object from IoT Hub that will contain relevant information including: correlationId, hostName, containerName, blobName, sasToken. + """ + callback = EventedCallback(return_arg_name="storage_info") + self._http_pipeline.get_storage_info_for_blob(blob_name, callback=callback) + storage_info = handle_result(callback) + logger.info("Successfully retrieved storage_info") + return storage_info + + def notify_blob_upload_status( + self, correlation_id, is_success, status_code, status_description + ): + """When the upload is complete, the device sends a POST request to the IoT Hub endpoint with information on the status of an upload to blob attempt. This is used by IoT Hub to notify listening clients. + + :param str correlation_id: Provided by IoT Hub on get_storage_info_for_blob request. + :param bool is_success: A boolean that indicates whether the file was uploaded successfully. + :param int status_code: A numeric status code that is the status for the upload of the fiel to storage. + :param str status_description: A description that corresponds to the status_code. + """ + callback = EventedCallback() + self._http_pipeline.notify_blob_upload_status( + correlation_id=correlation_id, + is_success=is_success, + status_code=status_code, + status_description=status_description, + callback=callback, + ) + handle_result(callback) + logger.info("Successfully notified blob upload status") + class IoTHubModuleClient(GenericIoTHubClient, AbstractIoTHubModuleClient): """A synchronous module client that connects to an Azure IoT Hub or Azure IoT Edge instance. @@ -295,21 +423,23 @@ class IoTHubModuleClient(GenericIoTHubClient, AbstractIoTHubModuleClient): Intended for usage with Python 2.7 or compatibility scenarios for Python 3.5.3+. """ - def __init__(self, iothub_pipeline, edge_pipeline=None): + def __init__(self, iothub_pipeline, http_pipeline): """Intializer for a IoTHubModuleClient. This initializer should not be called directly. Instead, use one of the 'create_from_' classmethods to instantiate :param iothub_pipeline: The pipeline used to connect to the IoTHub endpoint. - :type iothub_pipeline: IoTHubPipeline - :param edge_pipeline: (OPTIONAL) The pipeline used to connect to the Edge endpoint. - :type edge_pipeline: EdgePipeline + :type iothub_pipeline: :class:`azure.iot.device.iothub.pipeline.IoTHubPipeline` + :param http_pipeline: The pipeline used to connect to the IoTHub endpoint via HTTP. + :type http_pipeline: :class:`azure.iot.device.iothub.pipeline.HTTPPipeline` """ super(IoTHubModuleClient, self).__init__( - iothub_pipeline=iothub_pipeline, edge_pipeline=edge_pipeline + iothub_pipeline=iothub_pipeline, http_pipeline=http_pipeline + ) + self._iothub_pipeline.on_input_message_received = CallableWeakMethod( + self._inbox_manager, "route_input_message" ) - self._iothub_pipeline.on_input_message_received = self._inbox_manager.route_input_message def send_message_to_output(self, message, output_name): """Sends an event/message to the given module output. @@ -322,19 +452,34 @@ class IoTHubModuleClient(GenericIoTHubClient, AbstractIoTHubModuleClient): If the connection to the service has not previously been opened by a call to connect, this function will open the connection before sending the event. - :param message: message to send to the given output. Anything passed that is not an instance of the - Message class will be converted to Message object. - :param output_name: Name of the output to send the event to. + :param message: Message to send to the given output. Anything passed that is not an instance of the + Message class will be converted to Message object. + :type message: :class:`azure.iot.device.Message` or str + :param str output_name: Name of the output to send the event to. + + :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid + and a connection cannot be established. + :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a + connection results in failure. + :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost + during execution. + :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure + during execution. + :raises: ValueError if the message fails size validation. """ if not isinstance(message, Message): message = Message(message) + + if message.get_size() > device_constant.TELEMETRY_MESSAGE_SIZE_LIMIT: + raise ValueError("Size of message can not exceed 256 KB.") + message.output_name = output_name logger.info("Sending message to output:" + output_name + "...") callback = EventedCallback() self._iothub_pipeline.send_output_event(message, callback=callback) - callback.wait_for_completion() + handle_result(callback) logger.info("Successfully sent message to output: " + output_name) @@ -343,19 +488,37 @@ class IoTHubModuleClient(GenericIoTHubClient, AbstractIoTHubModuleClient): :param str input_name: The input name to receive a message on. :param bool block: Indicates if the operation should block until a message is received. - Default True. :param int timeout: Optionally provide a number of seconds until blocking times out. - :raises: InboxEmpty if timeout occurs on a blocking operation. - :raises: InboxEmpty if no message is available on a non-blocking operation. - - :returns: Message that was sent to the specified input. + :returns: Message that was sent to the specified input, or None if + no method request has been received by the end of the blocking period. """ - if not self._iothub_pipeline.feature_enabled[constant.INPUT_MSG]: - self._enable_feature(constant.INPUT_MSG) + if not self._iothub_pipeline.feature_enabled[pipeline_constant.INPUT_MSG]: + self._enable_feature(pipeline_constant.INPUT_MSG) input_inbox = self._inbox_manager.get_input_message_inbox(input_name) logger.info("Waiting for input message on: " + input_name + "...") - message = input_inbox.get(block=block, timeout=timeout) + try: + message = input_inbox.get(block=block, timeout=timeout) + except InboxEmpty: + message = None logger.info("Input message received on: " + input_name) return message + + def invoke_method(self, method_params, device_id, module_id=None): + """Invoke a method from your client onto a device or module client, and receive the response to the method call. + + :param dict method_params: Should contain a method_name, payload, connect_timeout_in_seconds, response_timeout_in_seconds. + :param str device_id: Device ID of the target device where the method will be invoked. + :param str module_id: Module ID of the target module where the method will be invoked. (Optional) + + :returns: method_result should contain a status, and a payload + :rtype: dict + """ + callback = EventedCallback(return_arg_name="invoke_method_response") + self._http_pipeline.invoke_method( + device_id, method_params, callback=callback, module_id=module_id + ) + invoke_method_response = handle_result(callback) + logger.info("Successfully invoked method") + return invoke_method_response diff --git a/azure-iot-device/azure/iot/device/product_info.py b/azure-iot-device/azure/iot/device/product_info.py new file mode 100644 index 000000000..c8b735b12 --- /dev/null +++ b/azure-iot-device/azure/iot/device/product_info.py @@ -0,0 +1,50 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import platform +from azure.iot.device.constant import VERSION, IOTHUB_IDENTIFIER, PROVISIONING_IDENTIFIER + +python_runtime = platform.python_version() +os_type = platform.system() +os_release = platform.version() +architecture = platform.machine() + + +class ProductInfo(object): + """ + A class for creating product identifiers or agent strings for IotHub as well as Provisioning. + """ + + @staticmethod + def _get_common_user_agent(): + return "({python_runtime};{os_type} {os_release};{architecture})".format( + python_runtime=python_runtime, + os_type=os_type, + os_release=os_release, + architecture=architecture, + ) + + @staticmethod + def get_iothub_user_agent(): + """ + Create the user agent for IotHub + """ + return "{iothub_iden}/{version}{common}".format( + iothub_iden=IOTHUB_IDENTIFIER, + version=VERSION, + common=ProductInfo._get_common_user_agent(), + ) + + @staticmethod + def get_provisioning_user_agent(): + """ + Create the user agent for Provisioning + """ + return "{provisioning_iden}/{version}{common}".format( + provisioning_iden=PROVISIONING_IDENTIFIER, + version=VERSION, + common=ProductInfo._get_common_user_agent(), + ) diff --git a/azure-iot-device/azure/iot/device/provisioning/abstract_provisioning_device_client.py b/azure-iot-device/azure/iot/device/provisioning/abstract_provisioning_device_client.py index 69cf0d86b..576570499 100644 --- a/azure-iot-device/azure/iot/device/provisioning/abstract_provisioning_device_client.py +++ b/azure-iot-device/azure/iot/device/provisioning/abstract_provisioning_device_client.py @@ -11,13 +11,22 @@ Device Provisioning Service. import abc import six import logging -from .security.sk_security_client import SymmetricKeySecurityClient -from .security.x509_security_client import X509SecurityClient -from azure.iot.device.provisioning.pipeline.provisioning_pipeline import ProvisioningPipeline +from azure.iot.device.provisioning import pipeline, security logger = logging.getLogger(__name__) +def _validate_kwargs(**kwargs): + """Helper function to validate user provided kwargs. + Raises TypeError if an invalid option has been provided""" + # TODO: add support for server_verification_cert + valid_kwargs = ["websockets", "cipher"] + + for kwarg in kwargs: + if kwarg not in valid_kwargs: + raise TypeError("Got an unexpected keyword argument '{}'".format(kwarg)) + + @six.add_metaclass(abc.ABCMeta) class AbstractProvisioningDeviceClient(object): """ @@ -27,80 +36,110 @@ class AbstractProvisioningDeviceClient(object): def __init__(self, provisioning_pipeline): """ Initializes the provisioning client. + + NOTE: This initializer should not be called directly. + Instead, the class methods that start with `create_from_` should be used to create a + client object. + :param provisioning_pipeline: Instance of the provisioning pipeline object. + :type provisioning_pipeline: :class:`azure.iot.device.provisioning.pipeline.ProvisioningPipeline` """ self._provisioning_pipeline = provisioning_pipeline + self._provisioning_payload = None @classmethod def create_from_symmetric_key( - cls, provisioning_host, registration_id, id_scope, symmetric_key, protocol_choice=None + cls, provisioning_host, registration_id, id_scope, symmetric_key, **kwargs ): """ Create a client which can be used to run the registration of a device with provisioning service using Symmetric Key authentication. - :param provisioning_host: Host running the Device Provisioning Service. Can be found in the Azure portal in the - Overview tab as the string Global device endpoint - :param registration_id: The registration ID is used to uniquely identify a device in the Device Provisioning Service. - The registration ID is alphanumeric, lowercase string and may contain hyphens. - :param id_scope: The ID scope is used to uniquely identify the specific provisioning service the device will - register through. The ID scope is assigned to a Device Provisioning Service when it is created by the user and - is generated by the service and is immutable, guaranteeing uniqueness. - :param symmetric_key: The key which will be used to create the shared access signature token to authenticate - the device with the Device Provisioning Service. By default, the Device Provisioning Service creates - new symmetric keys with a default length of 32 bytes when new enrollments are saved with the Auto-generate keys - option enabled. Users can provide their own symmetric keys for enrollments by disabling this option within - 16 bytes and 64 bytes and in valid Base64 format. - :param protocol_choice: The choice for the protocol to be used. This is optional and will default to protocol MQTT currently. - :return: A ProvisioningDeviceClient which can register via Symmetric Key. + + :param str provisioning_host: Host running the Device Provisioning Service. + Can be found in the Azure portal in the Overview tab as the string Global device endpoint. + :param str registration_id: The registration ID used to uniquely identify a device in the + Device Provisioning Service. The registration ID is alphanumeric, lowercase string + and may contain hyphens. + :param str id_scope: The ID scope used to uniquely identify the specific provisioning + service the device will register through. The ID scope is assigned to a + Device Provisioning Service when it is created by the user and is generated by the + service and is immutable, guaranteeing uniqueness. + :param str symmetric_key: The key which will be used to create the shared access signature + token to authenticate the device with the Device Provisioning Service. By default, + the Device Provisioning Service creates new symmetric keys with a default length of + 32 bytes when new enrollments are saved with the Auto-generate keys option enabled. + Users can provide their own symmetric keys for enrollments by disabling this option + within 16 bytes and 64 bytes and in valid Base64 format. + + :param bool websockets: Configuration Option. Default is False. Set to true if using MQTT + over websockets. + :param cipher: Configuration Option. Cipher suite(s) for TLS/SSL, as a string in + "OpenSSL cipher list format" or as a list of cipher suite strings. + :type cipher: str or list(str) + + :raises: TypeError if given an unrecognized parameter. + + :returns: A ProvisioningDeviceClient instance which can register via Symmetric Key. """ - if protocol_choice is not None: - protocol_name = protocol_choice.lower() - else: - protocol_name = "mqtt" - if protocol_name == "mqtt": - security_client = SymmetricKeySecurityClient( - provisioning_host, registration_id, id_scope, symmetric_key - ) - mqtt_provisioning_pipeline = ProvisioningPipeline(security_client) - return cls(mqtt_provisioning_pipeline) - else: - raise NotImplementedError( - "A symmetric key can only create symmetric key security client which is compatible " - "only with MQTT protocol.Any other protocol has not been implemented." - ) + _validate_kwargs(**kwargs) + + security_client = security.SymmetricKeySecurityClient( + provisioning_host=provisioning_host, + registration_id=registration_id, + id_scope=id_scope, + symmetric_key=symmetric_key, + ) + pipeline_configuration = pipeline.ProvisioningPipelineConfig(**kwargs) + mqtt_provisioning_pipeline = pipeline.ProvisioningPipeline( + security_client, pipeline_configuration + ) + return cls(mqtt_provisioning_pipeline) @classmethod def create_from_x509_certificate( - cls, provisioning_host, registration_id, id_scope, x509, protocol_choice=None + cls, provisioning_host, registration_id, id_scope, x509, **kwargs ): """ - Create a client which can be used to run the registration of a device with provisioning service - using X509 certificate authentication. - :param provisioning_host: Host running the Device Provisioning Service. Can be found in the Azure portal in the - Overview tab as the string Global device endpoint - :param registration_id: The registration ID is used to uniquely identify a device in the Device Provisioning Service. - The registration ID is alphanumeric, lowercase string and may contain hyphens. - :param id_scope: The ID scope is used to uniquely identify the specific provisioning service the device will - register through. The ID scope is assigned to a Device Provisioning Service when it is created by the user and - is generated by the service and is immutable, guaranteeing uniqueness. - :param x509: The x509 certificate, To use the certificate the enrollment object needs to contain cert (either the root certificate or one of the intermediate CA certificates). - If the cert comes from a CER file, it needs to be base64 encoded. - :param protocol_choice: The choice for the protocol to be used. This is optional and will default to protocol MQTT currently. - :return: A ProvisioningDeviceClient which can register via Symmetric Key. + Create a client which can be used to run the registration of a device with + provisioning service using X509 certificate authentication. + + :param str provisioning_host: Host running the Device Provisioning Service. Can be found in + the Azure portal in the Overview tab as the string Global device endpoint. + :param str registration_id: The registration ID used to uniquely identify a device in the + Device Provisioning Service. The registration ID is alphanumeric, lowercase string + and may contain hyphens. + :param str id_scope: The ID scope is used to uniquely identify the specific + provisioning service the device will register through. The ID scope is assigned to a + Device Provisioning Service when it is created by the user and is generated by the + service and is immutable, guaranteeing uniqueness. + :param x509: The x509 certificate, To use the certificate the enrollment object needs to + contain cert (either the root certificate or one of the intermediate CA certificates). + If the cert comes from a CER file, it needs to be base64 encoded. + :type x509: :class:`azure.iot.device.X509` + + :param bool websockets: Configuration Option. Default is False. Set to true if using MQTT + over websockets. + :param cipher: Configuration Option. Cipher suite(s) for TLS/SSL, as a string in + "OpenSSL cipher list format" or as a list of cipher suite strings. + :type cipher: str or list(str) + + :raises: TypeError if given an unrecognized parameter. + + :returns: A ProvisioningDeviceClient which can register via Symmetric Key. """ - if protocol_choice is None: - protocol_name = "mqtt" - else: - protocol_name = protocol_choice.lower() - if protocol_name == "mqtt": - security_client = X509SecurityClient(provisioning_host, registration_id, id_scope, x509) - mqtt_provisioning_pipeline = ProvisioningPipeline(security_client) - return cls(mqtt_provisioning_pipeline) - else: - raise NotImplementedError( - "A x509 certificate can only create x509 security client which is compatible only " - "with MQTT protocol.Any other protocol has not been implemented." - ) + _validate_kwargs(**kwargs) + + security_client = security.X509SecurityClient( + provisioning_host=provisioning_host, + registration_id=registration_id, + id_scope=id_scope, + x509=x509, + ) + pipeline_configuration = pipeline.ProvisioningPipelineConfig(**kwargs) + mqtt_provisioning_pipeline = pipeline.ProvisioningPipeline( + security_client, pipeline_configuration + ) + return cls(mqtt_provisioning_pipeline) @abc.abstractmethod def register(self): @@ -109,12 +148,19 @@ class AbstractProvisioningDeviceClient(object): """ pass - @abc.abstractmethod - def cancel(self): + @property + def provisioning_payload(self): + return self._provisioning_payload + + @provisioning_payload.setter + def provisioning_payload(self, provisioning_payload): """ - Cancel an in progress registration of the device with the Device Provisioning Service. + Set the payload that will form the request payload in a registration request. + + :param provisioning_payload: The payload that can be supplied by the user. + :type provisioning_payload: This can be an object or dictionary or a string or an integer. """ - pass + self._provisioning_payload = provisioning_payload def log_on_register_complete(result=None): diff --git a/azure-iot-device/azure/iot/device/provisioning/aio/async_provisioning_device_client.py b/azure-iot-device/azure/iot/device/provisioning/aio/async_provisioning_device_client.py index 748c4b830..d895b6bee 100644 --- a/azure-iot-device/azure/iot/device/provisioning/aio/async_provisioning_device_client.py +++ b/azure-iot-device/azure/iot/device/provisioning/aio/async_provisioning_device_client.py @@ -3,8 +3,10 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -"""This module contains user-facing asynchronous clients for the -Azure Provisioning Device SDK for Python. +""" +This module contains user-facing asynchronous Provisioning Device Client for Azure Provisioning +Device SDK. This client uses Symmetric Key and X509 authentication to register devices with an +IoT Hub via the Device Provisioning Service. """ import logging @@ -15,55 +17,77 @@ from azure.iot.device.provisioning.abstract_provisioning_device_client import ( from azure.iot.device.provisioning.abstract_provisioning_device_client import ( log_on_register_complete, ) -from azure.iot.device.provisioning.internal.polling_machine import PollingMachine +from azure.iot.device.provisioning.pipeline import exceptions as pipeline_exceptions +from azure.iot.device import exceptions +from azure.iot.device.provisioning.pipeline import constant as dps_constant logger = logging.getLogger(__name__) +async def handle_result(callback): + try: + return await callback.completion() + except pipeline_exceptions.ConnectionDroppedError as e: + raise exceptions.ConnectionDroppedError(message="Lost connection to IoTHub", cause=e) + except pipeline_exceptions.ConnectionFailedError as e: + raise exceptions.ConnectionFailedError(message="Could not connect to IoTHub", cause=e) + except pipeline_exceptions.UnauthorizedError as e: + raise exceptions.CredentialError(message="Credentials invalid, could not connect", cause=e) + except pipeline_exceptions.ProtocolClientError as e: + raise exceptions.ClientError(message="Error in the IoTHub client", cause=e) + except Exception as e: + raise exceptions.ClientError(message="Unexpected failure", cause=e) + + class ProvisioningDeviceClient(AbstractProvisioningDeviceClient): """ Client which can be used to run the registration of a device with provisioning service - using Symmetric Key authentication. + using Symmetric Key or X509 authentication. """ - def __init__(self, provisioning_pipeline): - """ - Initializer for the Provisioning Client. - NOTE : This initializer should not be called directly. - Instead, the class method `create_from_security_client` should be used to create a client object. - :param provisioning_pipeline: The protocol pipeline for provisioning. As of now this only supports MQTT. - """ - super(ProvisioningDeviceClient, self).__init__(provisioning_pipeline) - self._polling_machine = PollingMachine(provisioning_pipeline) - async def register(self): """ Register the device with the provisioning service. + Before returning the client will also disconnect from the provisioning service. - If a registration attempt is made while a previous registration is in progress it may throw an error. + If a registration attempt is made while a previous registration is in progress it may + throw an error. + + :returns: RegistrationResult indicating the result of the registration. + :rtype: :class:`azure.iot.device.RegistrationResult` + + :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid + and a connection cannot be established. + :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a + connection results in failure. + :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost + during execution. + :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure + during execution. + """ logger.info("Registering with Provisioning Service...") - register_async = async_adapter.emulate_async(self._polling_machine.register) - callback = async_adapter.AwaitableCallback(return_arg_name="result") - await register_async(callback=callback) - result = await callback.completion() + if not self._provisioning_pipeline.responses_enabled[dps_constant.REGISTER]: + await self._enable_responses() + + register_async = async_adapter.emulate_async(self._provisioning_pipeline.register) + + register_complete = async_adapter.AwaitableCallback(return_arg_name="result") + await register_async(payload=self._provisioning_payload, callback=register_complete) + result = await handle_result(register_complete) log_on_register_complete(result) return result - async def cancel(self): + async def _enable_responses(self): + """Enable to receive responses from Device Provisioning Service. """ - Before returning the client will also disconnect from the provisioning service. + logger.info("Enabling reception of response from Device Provisioning Service...") + subscribe_async = async_adapter.emulate_async(self._provisioning_pipeline.enable_responses) - In case there is no registration in process it will throw an error as there is - no registration process to cancel. - """ - logger.info("Disconnecting from Provisioning Service...") - cancel_async = async_adapter.emulate_async(self._polling_machine.cancel) + subscription_complete = async_adapter.AwaitableCallback() + await subscribe_async(callback=subscription_complete) + await handle_result(subscription_complete) - callback = async_adapter.AwaitableCallback() - await cancel_async(callback=callback) - await callback.completion() - - logger.info("Successfully cancelled the current registration process") + logger.info("Successfully subscribed to Device Provisioning Service to receive responses") diff --git a/azure-iot-device/azure/iot/device/provisioning/internal/__init__.py b/azure-iot-device/azure/iot/device/provisioning/internal/__init__.py deleted file mode 100644 index ba8577988..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/internal/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -"""Azure Provisioning Device Internal - -This package provides internal classes for use within the Azure Provisioning Device SDK. -""" diff --git a/azure-iot-device/azure/iot/device/provisioning/internal/polling_machine.py b/azure-iot-device/azure/iot/device/provisioning/internal/polling_machine.py deleted file mode 100644 index 688f9d3f8..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/internal/polling_machine.py +++ /dev/null @@ -1,450 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import logging -import uuid -import json -import traceback -from threading import Timer -from transitions import Machine -from azure.iot.device.provisioning.pipeline import constant -import six.moves.urllib as urllib -from .request_response_provider import RequestResponseProvider -from azure.iot.device.provisioning.models.registration_result import ( - RegistrationResult, - RegistrationState, -) -from .registration_query_status_result import RegistrationQueryStatusResult - - -logger = logging.getLogger(__name__) - -POS_STATUS_CODE_IN_TOPIC = 3 -POS_QUERY_PARAM_PORTION = 2 - - -class PollingMachine(object): - """ - Class that is responsible for sending the initial registration request and polling the - registration process for constant updates. - """ - - def __init__(self, provisioning_pipeline): - """ - :param provisioning_pipeline: The pipeline for provisioning. - """ - self._polling_timer = None - self._query_timer = None - - self._register_callback = None - self._cancel_callback = None - - self._registration_error = None - self._registration_result = None - - self._operations = {} - - self._request_response_provider = RequestResponseProvider(provisioning_pipeline) - - states = [ - "disconnected", - "initializing", - "registering", - "waiting_to_poll", - "polling", - "completed", - "error", - "cancelling", - ] - - transitions = [ - { - "trigger": "_trig_register", - "source": "disconnected", - "before": "_initialize_register", - "dest": "initializing", - }, - { - "trigger": "_trig_register", - "source": "error", - "before": "_initialize_register", - "dest": "initializing", - }, - {"trigger": "_trig_register", "source": "registering", "dest": None}, - { - "trigger": "_trig_send_register_request", - "source": "initializing", - "before": "_send_register_request", - "dest": "registering", - }, - { - "trigger": "_trig_send_register_request", - "source": "waiting_to_poll", - "before": "_send_register_request", - "dest": "registering", - }, - { - "trigger": "_trig_wait", - "source": "registering", - "dest": "waiting_to_poll", - "after": "_wait_for_interval", - }, - {"trigger": "_trig_wait", "source": "cancelling", "dest": None}, - { - "trigger": "_trig_wait", - "source": "polling", - "dest": "waiting_to_poll", - "after": "_wait_for_interval", - }, - { - "trigger": "_trig_poll", - "source": "waiting_to_poll", - "dest": "polling", - "after": "_query_operation_status", - }, - {"trigger": "_trig_poll", "source": "cancelling", "dest": None}, - { - "trigger": "_trig_complete", - "source": ["registering", "waiting_to_poll", "polling"], - "dest": "completed", - "after": "_call_complete", - }, - { - "trigger": "_trig_error", - "source": ["registering", "waiting_to_poll", "polling"], - "dest": "error", - "after": "_call_error", - }, - {"trigger": "_trig_error", "source": "cancelling", "dest": None}, - { - "trigger": "_trig_cancel", - "source": ["disconnected", "completed"], - "dest": None, - "after": "_inform_no_process", - }, - { - "trigger": "_trig_cancel", - "source": ["initializing", "registering", "waiting_to_poll", "polling"], - "dest": "cancelling", - "after": "_call_cancel", - }, - ] - - def _on_transition_complete(event_data): - if not event_data.transition: - dest = "[no transition]" - else: - dest = event_data.transition.dest - logger.debug( - "Transition complete. Trigger={}, Src={}, Dest={}, result={}, error{}".format( - event_data.event.name, - event_data.transition.source, - dest, - str(event_data.result), - str(event_data.error), - ) - ) - - self._state_machine = Machine( - model=self, - states=states, - transitions=transitions, - initial="disconnected", - send_event=True, # Use event_data structures to pass transition arguments - finalize_event=_on_transition_complete, - queued=True, - ) - - def register(self, callback=None): - """ - Register the device with the provisioning service. - :param:Callback to be called upon finishing the registration process - """ - logger.info("register called from polling machine") - self._register_callback = callback - self._trig_register() - - def cancel(self, callback=None): - """ - Cancels the current registration process of the device. - :param:Callback to be called upon finishing the cancellation process - """ - logger.info("cancel called from polling machine") - self._cancel_callback = callback - self._trig_cancel() - - def _initialize_register(self, event_data): - logger.info("Initializing the registration process.") - self._request_response_provider.enable_responses(callback=self._on_subscribe_completed) - - def _send_register_request(self, event_data): - """ - Send the registration request. - """ - logger.info("Sending registration request") - self._set_query_timer() - - request_id = str(uuid.uuid4()) - - self._operations[request_id] = constant.PUBLISH_TOPIC_REGISTRATION.format(request_id) - self._request_response_provider.send_request( - request_id=request_id, - request_payload=" ", - operation_id=None, - callback_on_response=self._on_register_response_received, - ) - - def _query_operation_status(self, event_data): - """ - Poll the service for operation status. - """ - logger.info("Querying operation status from polling machine") - self._set_query_timer() - - request_id = str(uuid.uuid4()) - result = event_data.args[0].args[0] - - operation_id = result.operation_id - self._operations[request_id] = constant.PUBLISH_TOPIC_QUERYING.format( - request_id, operation_id - ) - self._request_response_provider.send_request( - request_id=request_id, - request_payload=" ", - operation_id=operation_id, - callback_on_response=self._on_query_response_received, - ) - - def _on_register_response_received(self, request_id, status_code, key_values_dict, response): - """ - The function to call in case of a response from a registration request. - :param request_id: The id of the original register request. - :param status_code: The status code in the response. - :param key_values_dict: The dictionary containing the query parameters of the returned topic. - :param response: The complete response from the service. - """ - self._query_timer.cancel() - - retry_after = ( - None if "retry-after" not in key_values_dict else str(key_values_dict["retry-after"][0]) - ) - intermediate_registration_result = RegistrationQueryStatusResult(request_id, retry_after) - - if int(status_code, 10) >= 429: - del self._operations[request_id] - self._trig_wait(intermediate_registration_result) - elif int(status_code, 10) >= 300: # pure failure - self._registration_error = ValueError("Incoming message failure") - self._trig_error() - else: # successful case, transition into complete or poll status - self._process_successful_response(request_id, retry_after, response) - - def _on_query_response_received(self, request_id, status_code, key_values_dict, response): - """ - The function to call in case of a response from a polling/query request. - :param request_id: The id of the original query request. - :param status_code: The status code in the response. - :param key_values_dict: The dictionary containing the query parameters of the returned topic. - :param response: The complete response from the service. - """ - self._query_timer.cancel() - self._polling_timer.cancel() - - retry_after = ( - None if "retry-after" not in key_values_dict else str(key_values_dict["retry-after"][0]) - ) - intermediate_registration_result = RegistrationQueryStatusResult(request_id, retry_after) - - if int(status_code, 10) >= 429: - if request_id in self._operations: - publish_query_topic = self._operations[request_id] - del self._operations[request_id] - topic_parts = publish_query_topic.split("$") - key_values_publish_topic = urllib.parse.parse_qs( - topic_parts[POS_QUERY_PARAM_PORTION] - ) - operation_id = key_values_publish_topic["operationId"][0] - intermediate_registration_result.operation_id = operation_id - self._trig_wait(intermediate_registration_result) - else: - self._registration_error = ValueError("This request was never sent") - self._trig_error() - elif int(status_code, 10) >= 300: # pure failure - self._registration_error = ValueError("Incoming message failure") - self._trig_error() - else: # successful status code case, transition into complete or another poll status - self._process_successful_response(request_id, retry_after, response) - - def _process_successful_response(self, request_id, retry_after, response): - """ - Fucntion to call in case of 200 response from the service - :param request_id: The request id - :param retry_after: The time after which to try again. - :param response: The complete response - """ - del self._operations[request_id] - successful_result = self._decode_json_response(request_id, retry_after, response) - if successful_result.status == "assigning": - self._trig_wait(successful_result) - elif successful_result.status == "assigned" or successful_result.status == "failed": - complete_registration_result = self._decode_complete_json_response( - successful_result, response - ) - self._registration_result = complete_registration_result - self._trig_complete() - else: - self._registration_error = ValueError("Other types of failure have occurred.", response) - self._trig_error() - - def _inform_no_process(self, event_data): - raise RuntimeError("There is no registration process to cancel.") - - def _call_cancel(self, event_data): - """ - Completes the cancellation process - """ - logger.info("Cancel called from polling machine") - self._clear_timers() - self._request_response_provider.disconnect(callback=self._on_disconnect_completed_cancel) - - def _call_error(self, event_data): - logger.info("Failed register from polling machine") - - self._clear_timers() - self._request_response_provider.disconnect(callback=self._on_disconnect_completed_error) - - def _call_complete(self, event_data): - logger.info("Complete register from polling machine") - self._clear_timers() - self._request_response_provider.disconnect(callback=self._on_disconnect_completed_register) - - def _clear_timers(self): - """ - Clears all the timers and disconnects from the service - """ - if self._query_timer is not None: - self._query_timer.cancel() - if self._polling_timer is not None: - self._polling_timer.cancel() - - def _set_query_timer(self): - def time_up_query(): - logger.error("Time is up for query timer") - self._query_timer.cancel() - # TimeoutError not defined in python 2 - self._registration_error = ValueError("Time is up for query timer") - self._trig_error() - - self._query_timer = Timer(constant.DEFAULT_TIMEOUT_INTERVAL, time_up_query) - self._query_timer.start() - - def _wait_for_interval(self, event_data): - def time_up_polling(): - self._polling_timer.cancel() - logger.debug("Done waiting for polling interval of {} secs".format(polling_interval)) - if result.operation_id is None: - self._trig_send_register_request(event_data) - else: - self._trig_poll(event_data) - - result = event_data.args[0] - polling_interval = ( - constant.DEFAULT_POLLING_INTERVAL - if result.retry_after is None - else int(result.retry_after, 10) - ) - - self._polling_timer = Timer(polling_interval, time_up_polling) - logger.debug("Waiting for " + str(constant.DEFAULT_POLLING_INTERVAL) + " secs") - self._polling_timer.start() # This is waiting for that polling interval - - def _decode_complete_json_response(self, query_result, response): - """ - Decodes the complete json response for details regarding the registration process. - :param query_result: The partially formed result. - :param response: The complete response from the service - """ - decoded_result = json.loads(response) - - decoded_state = ( - None - if "registrationState" not in decoded_result - else decoded_result["registrationState"] - ) - registration_state = None - if decoded_state is not None: - # Everything needs to be converted to string explicitly for python 2 - # as everything is by default a unicode character - registration_state = RegistrationState( - None if "deviceId" not in decoded_state else str(decoded_state["deviceId"]), - None if "assignedHub" not in decoded_state else str(decoded_state["assignedHub"]), - None if "substatus" not in decoded_state else str(decoded_state["substatus"]), - None - if "createdDateTimeUtc" not in decoded_state - else str(decoded_state["createdDateTimeUtc"]), - None - if "lastUpdatedDateTimeUtc" not in decoded_state - else str(decoded_state["lastUpdatedDateTimeUtc"]), - None if "etag" not in decoded_state else str(decoded_state["etag"]), - ) - - registration_result = RegistrationResult( - request_id=query_result.request_id, - operation_id=query_result.operation_id, - status=query_result.status, - registration_state=registration_state, - ) - return registration_result - - def _decode_json_response(self, request_id, retry_after, response): - """ - Decodes the json response for operation id and status - :param request_id: The request id. - :param retry_after: The time in secs after which to retry. - :param response: The complete response from the service. - """ - decoded_result = json.loads(response) - - operation_id = ( - None if "operationId" not in decoded_result else str(decoded_result["operationId"]) - ) - status = None if "status" not in decoded_result else str(decoded_result["status"]) - - return RegistrationQueryStatusResult(request_id, retry_after, operation_id, status) - - def _on_disconnect_completed_error(self): - logger.info("on_disconnect_completed for Device Provisioning Service") - callback = self._register_callback - if callback: - self._register_callback = None - try: - callback(error=self._registration_error) - except Exception: - logger.error("Unexpected error calling callback supplied to register") - logger.error(traceback.format_exc()) - - def _on_disconnect_completed_cancel(self): - logger.info("on_disconnect_completed after cancelling current Device Provisioning Service") - callback = self._cancel_callback - - if callback: - self._cancel_callback = None - callback() - - def _on_disconnect_completed_register(self): - logger.info("on_disconnect_completed after registration to Device Provisioning Service") - callback = self._register_callback - - if callback: - self._register_callback = None - try: - callback(result=self._registration_result) - except Exception: - logger.error("Unexpected error calling callback supplied to register") - logger.error(traceback.format_exc()) - - def _on_subscribe_completed(self): - logger.debug("on_subscribe_completed for Device Provisioning Service") - self._trig_send_register_request() diff --git a/azure-iot-device/azure/iot/device/provisioning/internal/registration_query_status_result.py b/azure-iot-device/azure/iot/device/provisioning/internal/registration_query_status_result.py deleted file mode 100644 index 55215fd11..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/internal/registration_query_status_result.py +++ /dev/null @@ -1,58 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - - -class RegistrationQueryStatusResult(object): - """ - The result of any registration attempt - :ivar:request_id: The request id to which the response is being obtained - :ivar:operation_id: The id of the operation as returned by the registration request. - :ivar status: The status of the registration process as returned by provisioning service. - Values can be "unassigned", "assigning", "assigned", "failed", "disabled" - from the provisioning service. - """ - - def __init__(self, request_id=None, retry_after=None, operation_id=None, status=None): - """ - :param request_id: The request id to which the response is being obtained - :param retry_after : Number of secs after which to retry again. - :param operation_id: The id of the operation as returned by the initial registration request. - :param status: The status of the registration process. - Values can be "unassigned", "assigning", "assigned", "failed", "disabled" - from the provisioning service. - """ - self._request_id = request_id - self._operation_id = operation_id - self._status = status - self._retry_after = retry_after - - @property - def request_id(self): - return self._request_id - - @property - def retry_after(self): - return self._retry_after - - @retry_after.setter - def retry_after(self, val): - self._retry_after = val - - @property - def operation_id(self): - return self._operation_id - - @operation_id.setter - def operation_id(self, val): - self._operation_id = val - - @property - def status(self): - return self._status - - @status.setter - def status(self, val): - self._status = val diff --git a/azure-iot-device/azure/iot/device/provisioning/internal/request_response_provider.py b/azure-iot-device/azure/iot/device/provisioning/internal/request_response_provider.py deleted file mode 100644 index 05f4964ca..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/internal/request_response_provider.py +++ /dev/null @@ -1,101 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import logging - -logger = logging.getLogger(__name__) - -POS_STATUS_CODE_IN_TOPIC = 3 -POS_URL_PORTION = 1 -POS_QUERY_PARAM_PORTION = 2 - - -class RequestResponseProvider(object): - """ - Class that processes requests sent from device and responses received at device. - """ - - def __init__(self, provisioning_pipeline): - - self._provisioning_pipeline = provisioning_pipeline - - self._provisioning_pipeline.on_message_received = self._receive_response - - self._pending_requests = {} - - def send_request( - self, request_id, request_payload, operation_id=None, callback_on_response=None - ): - """ - Sends a request - :param request_id: Id of the request - :param request_payload: The payload of the request. - :param operation_id: A id of the operation in case it is an ongoing process. - :param callback_on_response: callback which is called when response comes back for this request. - """ - self._pending_requests[request_id] = callback_on_response - self._provisioning_pipeline.send_request( - request_id=request_id, - request_payload=request_payload, - operation_id=operation_id, - callback=self._on_publish_completed, - ) - - def connect(self, callback=None): - if callback is None: - callback = self._on_connection_state_change - self._provisioning_pipeline.connect(callback=callback) - - def disconnect(self, callback=None): - if callback is None: - callback = self._on_connection_state_change - self._provisioning_pipeline.disconnect(callback=callback) - - def enable_responses(self, callback=None): - if callback is None: - callback = self._on_subscribe_completed - self._provisioning_pipeline.enable_responses(callback=callback) - - def disable_responses(self, callback=None): - if callback is None: - callback = self._on_unsubscribe_completed - self._provisioning_pipeline.disable_responses(callback=callback) - - def _receive_response(self, request_id, status_code, key_value_dict, response_payload): - """ - Handler that processes the response from the service. - :param request_id: The id of the request which is being responded to. - :param status_code: The status code inside the response - :param key_value_dict: A dictionary of keys mapped to a list of values extracted from the topic of the response. - :param response_payload: String payload of the message received. - :return: - """ - # """ Sample topic and payload - # $dps/registrations/res/200/?$rid=28c32371-608c-4390-8da7-c712353c1c3b - # {"operationId":"4.550cb20c3349a409.390d2957-7b58-4701-b4f9-7fe848348f4a","status":"assigning"} - # """ - logger.debug("Received response {}:".format(response_payload)) - - if request_id in self._pending_requests: - callback = self._pending_requests[request_id] - # Only send the status code and the extracted topic - callback(request_id, status_code, key_value_dict, response_payload) - del self._pending_requests[request_id] - - # TODO : What happens when request_id if not there ? trigger error ? - - def _on_connection_state_change(self, new_state): - """Handler to be called by the pipeline upon a connection state change.""" - logger.info("Connection State - {}".format(new_state)) - - def _on_publish_completed(self): - logger.debug("publish completed for request response provider") - - def _on_subscribe_completed(self): - logger.debug("subscribe completed for request response provider") - - def _on_unsubscribe_completed(self): - logger.debug("on_unsubscribe_completed for request response provider") diff --git a/azure-iot-device/azure/iot/device/provisioning/models/registration_result.py b/azure-iot-device/azure/iot/device/provisioning/models/registration_result.py index 5d68661c5..560d76720 100644 --- a/azure-iot-device/azure/iot/device/provisioning/models/registration_result.py +++ b/azure-iot-device/azure/iot/device/provisioning/models/registration_result.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +import json class RegistrationResult(object): @@ -16,24 +17,18 @@ class RegistrationResult(object): from the provisioning service. """ - def __init__(self, request_id, operation_id, status, registration_state=None): + def __init__(self, operation_id, status, registration_state=None): """ - :param request_id: The request id to which the response is being obtained :param operation_id: The id of the operation as returned by the initial registration request. :param status: The status of the registration process. Values can be "unassigned", "assigning", "assigned", "failed", "disabled" :param registration_state : Details like device id, assigned hub , date times etc returned from the provisioning service. """ - self._request_id = request_id self._operation_id = operation_id self._status = status self._registration_state = registration_state - @property - def request_id(self): - return self._request_id - @property def operation_id(self): return self._operation_id @@ -70,6 +65,7 @@ class RegistrationState(object): created_date_time=None, last_update_date_time=None, etag=None, + payload=None, ): """ :param device_id: Desired device id for the provisioned device @@ -79,6 +75,7 @@ class RegistrationState(object): :param created_date_time: Registration create date time (in UTC). :param last_update_date_time: Last updated date time (in UTC). :param etag: The entity tag associated with the resource. + :param payload: The payload with which hub is responding """ self._device_id = device_id self._assigned_hub = assigned_hub @@ -86,6 +83,7 @@ class RegistrationState(object): self._created_date_time = created_date_time self._last_update_date_time = last_update_date_time self._etag = etag + self._response_payload = payload @property def device_id(self): @@ -111,5 +109,11 @@ class RegistrationState(object): def etag(self): return self._etag + @property + def response_payload(self): + return json.dumps(self._response_payload, default=lambda o: o.__dict__, sort_keys=True) + def __str__(self): - return "\n".join([self.device_id, self.assigned_hub, self.sub_status]) + return "\n".join( + [self.device_id, self.assigned_hub, self.sub_status, self.response_payload] + ) diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/__init__.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/__init__.py index 1089b6ae4..3a9ac5918 100644 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/__init__.py +++ b/azure-iot-device/azure/iot/device/provisioning/pipeline/__init__.py @@ -5,3 +5,4 @@ This package provides pipeline for use with the Azure Provisioning Device SDK. INTERNAL USAGE ONLY """ from .provisioning_pipeline import ProvisioningPipeline +from .config import ProvisioningPipelineConfig diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/config.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/config.py new file mode 100644 index 000000000..2a6342c04 --- /dev/null +++ b/azure-iot-device/azure/iot/device/provisioning/pipeline/config.py @@ -0,0 +1,17 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import logging +from azure.iot.device.common.pipeline.config import BasePipelineConfig + +logger = logging.getLogger(__name__) + + +class ProvisioningPipelineConfig(BasePipelineConfig): + """A class for storing all configurations/options for Provisioning clients in the Azure IoT Python Device Client Library. + """ + + pass diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/exceptions.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/exceptions.py new file mode 100644 index 000000000..bfa22a2b5 --- /dev/null +++ b/azure-iot-device/azure/iot/device/provisioning/pipeline/exceptions.py @@ -0,0 +1,21 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +"""This module defines an exception surface, exposed as part of the pipeline API""" + +# For now, present relevant transport errors as part of the Pipeline API surface +# so that they do not have to be duplicated at this layer. +# OK TODO This mimics the IotHub Case. Both IotHub and Provisioning needs to change +from azure.iot.device.common.pipeline.pipeline_exceptions import * +from azure.iot.device.common.transport_exceptions import ( + ConnectionFailedError, + ConnectionDroppedError, + # CT TODO: UnauthorizedError (the one from transport) should probably not surface out of + # the pipeline due to confusion with the higher level service UnauthorizedError. It + # should probably get turned into some other error instead (e.g. ConnectionFailedError). + # But for now, this is a stopgap. + UnauthorizedError, + ProtocolClientError, +) diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/mqtt_topic.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/mqtt_topic.py index 1131cde65..85581f8ef 100644 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/mqtt_topic.py +++ b/azure-iot-device/azure/iot/device/provisioning/pipeline/mqtt_topic.py @@ -24,24 +24,24 @@ def get_topic_for_subscribe(): return _get_topic_base() + "res/#" -def get_topic_for_register(request_id): +def get_topic_for_register(method, request_id): """ return the topic string used to publish telemetry """ - return (_get_topic_base() + "PUT/iotdps-register/?$rid={request_id}").format( - request_id=request_id + return (_get_topic_base() + "{method}/iotdps-register/?$rid={request_id}").format( + method=method, request_id=request_id ) -def get_topic_for_query(request_id, operation_id): +def get_topic_for_query(method, request_id, operation_id): """ :return: The topic for cloud to device messages.It is of the format "devices//messages/devicebound/#" """ return ( _get_topic_base() - + "GET/iotdps-get-operationstatus/?$rid={request_id}&operationId={operation_id}" - ).format(request_id=request_id, operation_id=operation_id) + + "{method}/iotdps-get-operationstatus/?$rid={request_id}&operationId={operation_id}" + ).format(method=method, request_id=request_id, operation_id=operation_id) def get_topic_for_response(): @@ -93,3 +93,22 @@ def extract_status_code_from_topic(topic): url_parts = topic_parts[1].split("/") status_code = url_parts[POS_STATUS_CODE_IN_TOPIC] return status_code + + +def get_optional_element(content, element_name, index=0): + """ + Gets an optional element from json string , or dictionary. + :param content: The content from which the element needs to be retrieved. + :param element_name: The name of the element + :param index: Optional index in case the return is a collection of elements. + """ + element = None if element_name not in content else content[element_name] + if element is None: + return None + else: + if isinstance(element, list): + return element[index] + elif isinstance(element, object): + return element + else: + return str(element) diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_events_provisioning.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_events_provisioning.py deleted file mode 100644 index 6b9bd2dd0..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_events_provisioning.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -from azure.iot.device.common.pipeline.pipeline_events_base import PipelineEvent - - -class RegistrationResponseEvent(PipelineEvent): - """ - A PipelineEvent object which represents an incoming RegistrationResponse event. This object is probably - created by some converter stage based on a pipeline-specific event - """ - - def __init__(self, request_id, status_code, key_values, response_payload): - """ - Initializer for RegistrationResponse objects. - :param request_id : The id of the request to which the response arrived. - :param status_code: The status code received in the topic. - :param key_values: A dictionary containing key mapped to a list of values that were extarcted from the topic. - :param response_payload: The response received from a registration process - """ - super(RegistrationResponseEvent, self).__init__() - self.request_id = request_id - self.status_code = status_code - self.key_values = key_values - self.response_payload = response_payload diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_ops_provisioning.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_ops_provisioning.py index f4d56fd21..eb30f6d1a 100644 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_ops_provisioning.py +++ b/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_ops_provisioning.py @@ -16,7 +16,7 @@ class SetSymmetricKeySecurityClientOperation(PipelineOperation): very provisioning-specific """ - def __init__(self, security_client, callback=None): + def __init__(self, security_client, callback): """ Initializer for SetSecurityClient. @@ -41,7 +41,7 @@ class SetX509SecurityClientOperation(PipelineOperation): (such as a Provisioning client). """ - def __init__(self, security_client, callback=None): + def __init__(self, security_client, callback): """ Initializer for SetSecurityClient. @@ -71,9 +71,9 @@ class SetProvisioningClientConnectionArgsOperation(PipelineOperation): provisioning_host, registration_id, id_scope, + callback, client_cert=None, sas_token=None, - callback=None, ): """ Initializer for SetProvisioningClientConnectionArgsOperation. @@ -91,7 +91,7 @@ class SetProvisioningClientConnectionArgsOperation(PipelineOperation): self.sas_token = sas_token -class SendRegistrationRequestOperation(PipelineOperation): +class RegisterOperation(PipelineOperation): """ A PipelineOperation object which contains arguments used to send a registration request to an Device Provisioning Service. @@ -99,22 +99,26 @@ class SendRegistrationRequestOperation(PipelineOperation): This operation is in the group of DPS operations because it is very specific to the DPS client. """ - def __init__(self, request_id, request_payload, callback=None): + def __init__(self, request_payload, registration_id, callback, registration_result=None): """ - Initializer for SendRegistrationRequestOperation objects. + Initializer for RegisterOperation objects. - :param request_id : The id of the request being sent :param request_payload: The request that we are sending to the service + :param registration_id: The registration ID is used to uniquely identify a device in the Device Provisioning Service. :param Function callback: The function that gets called when this operation is complete or has failed. The callback function must accept A PipelineOperation object which indicates the specific operation which has completed or failed. """ - super(SendRegistrationRequestOperation, self).__init__(callback=callback) - self.request_id = request_id + super(RegisterOperation, self).__init__(callback=callback) self.request_payload = request_payload + self.registration_id = registration_id + self.registration_result = registration_result + self.retry_after_timer = None + self.polling_timer = None + self.provisioning_timeout_timer = None -class SendQueryRequestOperation(PipelineOperation): +class PollStatusOperation(PipelineOperation): """ A PipelineOperation object which contains arguments used to send a registration request to an Device Provisioning Service. @@ -122,17 +126,20 @@ class SendQueryRequestOperation(PipelineOperation): This operation is in the group of DPS operations because it is very specific to the DPS client. """ - def __init__(self, request_id, operation_id, request_payload, callback=None): + def __init__(self, operation_id, request_payload, callback, registration_result=None): """ - Initializer for SendRegistrationRequestOperation objects. + Initializer for PollStatusOperation objects. - :param request_id + :param operation_id: The id of the existing operation for which the polling was started. :param request_payload: The request that we are sending to the service :param Function callback: The function that gets called when this operation is complete or has failed. The callback function must accept A PipelineOperation object which indicates the specific operation which has completed or failed. """ - super(SendQueryRequestOperation, self).__init__(callback=callback) - self.request_id = request_id + super(PollStatusOperation, self).__init__(callback=callback) self.operation_id = operation_id self.request_payload = request_payload + self.registration_result = registration_result + self.retry_after_timer = None + self.polling_timer = None + self.provisioning_timeout_timer = None diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning.py index cd9baa0f5..42faea6cb 100644 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning.py +++ b/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning.py @@ -4,9 +4,23 @@ # license information. # -------------------------------------------------------------------------- -from azure.iot.device.common.pipeline import pipeline_ops_base, operation_flow, pipeline_thread +from azure.iot.device.common.pipeline import pipeline_ops_base, pipeline_thread from azure.iot.device.common.pipeline.pipeline_stages_base import PipelineStage from . import pipeline_ops_provisioning +from azure.iot.device import exceptions +from azure.iot.device.provisioning.pipeline import constant +from azure.iot.device.provisioning.models.registration_result import ( + RegistrationResult, + RegistrationState, +) +import logging +import weakref +import json +from threading import Timer +import time +from .mqtt_topic import get_optional_element + +logger = logging.getLogger(__name__) class UseSecurityClientStage(PipelineStage): @@ -18,33 +32,477 @@ class UseSecurityClientStage(PipelineStage): """ @pipeline_thread.runs_on_pipeline_thread - def _execute_op(self, op): + def _run_op(self, op): if isinstance(op, pipeline_ops_provisioning.SetSymmetricKeySecurityClientOperation): security_client = op.security_client - operation_flow.delegate_to_different_op( - stage=self, - original_op=op, - new_op=pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation( - provisioning_host=security_client.provisioning_host, - registration_id=security_client.registration_id, - id_scope=security_client.id_scope, - sas_token=security_client.get_current_sas_token(), - ), + worker_op = op.spawn_worker_op( + worker_op_type=pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation, + provisioning_host=security_client.provisioning_host, + registration_id=security_client.registration_id, + id_scope=security_client.id_scope, + sas_token=security_client.get_current_sas_token(), ) + self.send_op_down(worker_op) elif isinstance(op, pipeline_ops_provisioning.SetX509SecurityClientOperation): security_client = op.security_client - operation_flow.delegate_to_different_op( - stage=self, - original_op=op, - new_op=pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation( - provisioning_host=security_client.provisioning_host, - registration_id=security_client.registration_id, - id_scope=security_client.id_scope, - client_cert=security_client.get_x509_certificate(), - ), + worker_op = op.spawn_worker_op( + worker_op_type=pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation, + provisioning_host=security_client.provisioning_host, + registration_id=security_client.registration_id, + id_scope=security_client.id_scope, + client_cert=security_client.get_x509_certificate(), + ) + self.send_op_down(worker_op) + + else: + super(UseSecurityClientStage, self)._run_op(op) + + +class CommonProvisioningStage(PipelineStage): + """ + This is a super stage that the RegistrationStage and PollingStatusStage of + provisioning would both use. It contains some common functions like decoding response + and retrieving error, retrieving registration status, retrieving operation id + and forming a complete result. + """ + + @pipeline_thread.runs_on_pipeline_thread + def _clear_timeout_timer(self, op, error): + """ + Clearing timer for provisioning operations (Register and PollStatus) + when they respond back from service. + """ + if op.provisioning_timeout_timer: + logger.debug("{}({}): Cancelling provisioning timeout timer".format(self.name, op.name)) + op.provisioning_timeout_timer.cancel() + op.provisioning_timeout_timer = None + + @staticmethod + def _decode_response(provisioning_op): + return json.loads(provisioning_op.response_body.decode("utf-8")) + + @staticmethod + def _get_registration_status(decoded_response): + return get_optional_element(decoded_response, "status") + + @staticmethod + def _get_operation_id(decoded_response): + return get_optional_element(decoded_response, "operationId") + + @staticmethod + def _form_complete_result(operation_id, decoded_response, status): + """ + Create the registration result from the complete decoded json response for details regarding the registration process. + """ + decoded_state = get_optional_element(decoded_response, "registrationState") + registration_state = None + if decoded_state is not None: + registration_state = RegistrationState( + device_id=get_optional_element(decoded_state, "deviceId"), + assigned_hub=get_optional_element(decoded_state, "assignedHub"), + sub_status=get_optional_element(decoded_state, "substatus"), + created_date_time=get_optional_element(decoded_state, "createdDateTimeUtc"), + last_update_date_time=get_optional_element(decoded_state, "lastUpdatedDateTimeUtc"), + etag=get_optional_element(decoded_state, "etag"), + payload=get_optional_element(decoded_state, "payload"), + ) + + registration_result = RegistrationResult( + operation_id=operation_id, status=status, registration_state=registration_state + ) + return registration_result + + def _process_service_error_status_code(self, original_provisioning_op, request_response_op): + logger.error( + "{stage_name}({op_name}): Received error with status code {status_code} for {prov_op_name} request operation".format( + stage_name=self.name, + op_name=request_response_op.name, + prov_op_name=request_response_op.request_type, + status_code=request_response_op.status_code, + ) + ) + logger.error( + "{stage_name}({op_name}): Response body: {body}".format( + stage_name=self.name, + op_name=request_response_op.name, + body=request_response_op.response_body, + ) + ) + original_provisioning_op.complete( + error=exceptions.ServiceError( + "{prov_op_name} request returned a service error status code {status_code}".format( + prov_op_name=request_response_op.request_type, + status_code=request_response_op.status_code, + ) + ) + ) + + def _process_retry_status_code(self, error, original_provisioning_op, request_response_op): + retry_interval = ( + int(request_response_op.retry_after, 10) + if request_response_op.retry_after is not None + else constant.DEFAULT_POLLING_INTERVAL + ) + + self_weakref = weakref.ref(self) + + @pipeline_thread.invoke_on_pipeline_thread_nowait + def do_retry_after(): + this = self_weakref() + logger.info( + "{stage_name}({op_name}): retrying".format( + stage_name=this.name, op_name=request_response_op.name + ) + ) + original_provisioning_op.retry_after_timer.cancel() + original_provisioning_op.retry_after_timer = None + original_provisioning_op.completed = False + this.run_op(original_provisioning_op) + + logger.warning( + "{stage_name}({op_name}): Op needs retry with interval {interval} because of {error}. Setting timer.".format( + stage_name=self.name, + op_name=request_response_op.name, + interval=retry_interval, + error=error, + ) + ) + + logger.debug("{}({}): Creating retry timer".format(self.name, request_response_op.name)) + original_provisioning_op.retry_after_timer = Timer(retry_interval, do_retry_after) + original_provisioning_op.retry_after_timer.start() + + @staticmethod + def _process_failed_and_assigned_registration_status( + error, + operation_id, + decoded_response, + registration_status, + original_provisioning_op, + request_response_op, + ): + complete_registration_result = CommonProvisioningStage._form_complete_result( + operation_id=operation_id, decoded_response=decoded_response, status=registration_status + ) + original_provisioning_op.registration_result = complete_registration_result + if registration_status == "failed": + error = exceptions.ServiceError( + "Query Status operation returned a failed registration status with a status code of {status_code}".format( + status_code=request_response_op.status_code + ) + ) + original_provisioning_op.complete(error=error) + + @staticmethod + def _process_unknown_registration_status( + registration_status, original_provisioning_op, request_response_op + ): + error = exceptions.ServiceError( + "Query Status Operation encountered an invalid registration status {status} with a status code of {status_code}".format( + status=registration_status, status_code=request_response_op.status_code + ) + ) + original_provisioning_op.complete(error=error) + + +class PollingStatusStage(CommonProvisioningStage): + """ + This stage is responsible for sending the query request once initial response + is received from the registration response. + Upon the receipt of the response this stage decides whether + to send another query request or complete the procedure. + """ + + @pipeline_thread.runs_on_pipeline_thread + def _run_op(self, op): + if isinstance(op, pipeline_ops_provisioning.PollStatusOperation): + query_status_op = op + self_weakref = weakref.ref(self) + + @pipeline_thread.invoke_on_pipeline_thread_nowait + def query_timeout(): + this = self_weakref() + logger.info( + "{stage_name}({op_name}): returning timeout error".format( + stage_name=this.name, op_name=op.name + ) + ) + query_status_op.complete( + error=( + exceptions.ServiceError( + "Operation timed out before provisioning service could respond for {op_type} operation".format( + op_type=constant.QUERY + ) + ) + ) + ) + + logger.debug("{}({}): Creating provisioning timeout timer".format(self.name, op.name)) + query_status_op.provisioning_timeout_timer = Timer( + constant.DEFAULT_TIMEOUT_INTERVAL, query_timeout + ) + query_status_op.provisioning_timeout_timer.start() + + def on_query_response(op, error): + self._clear_timeout_timer(query_status_op, error) + logger.debug( + "{stage_name}({op_name}): Received response with status code {status_code} for PollStatusOperation with operation id {oper_id}".format( + stage_name=self.name, + op_name=op.name, + status_code=op.status_code, + oper_id=op.query_params["operation_id"], + ) + ) + + if error: + logger.error( + "{stage_name}({op_name}): Received error for {prov_op_name} operation".format( + stage_name=self.name, op_name=op.name, prov_op_name=op.request_type + ) + ) + query_status_op.complete(error=error) + + else: + if 300 <= op.status_code < 429: + self._process_service_error_status_code(query_status_op, op) + + elif op.status_code >= 429: + self._process_retry_status_code(error, query_status_op, op) + + else: + decoded_response = self._decode_response(op) + operation_id = self._get_operation_id(decoded_response) + registration_status = self._get_registration_status(decoded_response) + if registration_status == "assigning": + polling_interval = ( + int(op.retry_after, 10) + if op.retry_after is not None + else constant.DEFAULT_POLLING_INTERVAL + ) + self_weakref = weakref.ref(self) + + @pipeline_thread.invoke_on_pipeline_thread_nowait + def do_polling(): + this = self_weakref() + logger.info( + "{stage_name}({op_name}): retrying".format( + stage_name=this.name, op_name=op.name + ) + ) + query_status_op.polling_timer.cancel() + query_status_op.polling_timer = None + query_status_op.completed = False + this.run_op(query_status_op) + + logger.info( + "{stage_name}({op_name}): Op needs retry with interval {interval} because of {error}. Setting timer.".format( + stage_name=self.name, + op_name=op.name, + interval=polling_interval, + error=error, + ) + ) + + logger.debug( + "{}({}): Creating polling timer".format(self.name, op.name) + ) + query_status_op.polling_timer = Timer(polling_interval, do_polling) + query_status_op.polling_timer.start() + + elif registration_status == "assigned" or registration_status == "failed": + self._process_failed_and_assigned_registration_status( + error=error, + operation_id=operation_id, + decoded_response=decoded_response, + registration_status=registration_status, + original_provisioning_op=query_status_op, + request_response_op=op, + ) + + else: + self._process_unknown_registration_status( + registration_status=registration_status, + original_provisioning_op=query_status_op, + request_response_op=op, + ) + + self.send_op_down( + pipeline_ops_base.RequestAndResponseOperation( + request_type=constant.QUERY, + method="GET", + resource_location="/", + query_params={"operation_id": query_status_op.operation_id}, + request_body=query_status_op.request_payload, + callback=on_query_response, + ) ) else: - operation_flow.pass_op_to_next_stage(self, op) + super(PollingStatusStage, self)._run_op(op) + + +class RegistrationStage(CommonProvisioningStage): + """ + This is the first stage that decides converts a registration request + into a normal request and response operation. + Upon the receipt of the response this stage decides whether + to send another registration request or send a query request. + Depending on the status and result of the response + this stage may also complete the registration process. + """ + + @pipeline_thread.runs_on_pipeline_thread + def _run_op(self, op): + if isinstance(op, pipeline_ops_provisioning.RegisterOperation): + initial_register_op = op + self_weakref = weakref.ref(self) + + @pipeline_thread.invoke_on_pipeline_thread_nowait + def register_timeout(): + this = self_weakref() + logger.info( + "{stage_name}({op_name}): returning timeout error".format( + stage_name=this.name, op_name=op.name + ) + ) + initial_register_op.complete( + error=( + exceptions.ServiceError( + "Operation timed out before provisioning service could respond for {op_type} operation".format( + op_type=constant.REGISTER + ) + ) + ) + ) + + logger.debug("{}({}): Creating provisioning timeout timer".format(self.name, op.name)) + initial_register_op.provisioning_timeout_timer = Timer( + constant.DEFAULT_TIMEOUT_INTERVAL, register_timeout + ) + initial_register_op.provisioning_timeout_timer.start() + + def on_registration_response(op, error): + self._clear_timeout_timer(initial_register_op, error) + logger.debug( + "{stage_name}({op_name}): Received response with status code {status_code} for RegisterOperation".format( + stage_name=self.name, op_name=op.name, status_code=op.status_code + ) + ) + if error: + logger.error( + "{stage_name}({op_name}): Received error for {prov_op_name} operation".format( + stage_name=self.name, op_name=op.name, prov_op_name=op.request_type + ) + ) + initial_register_op.complete(error=error) + + else: + + if 300 <= op.status_code < 429: + self._process_service_error_status_code(initial_register_op, op) + + elif op.status_code >= 429: + self._process_retry_status_code(error, initial_register_op, op) + + else: + decoded_response = self._decode_response(op) + operation_id = self._get_operation_id(decoded_response) + registration_status = self._get_registration_status(decoded_response) + + if registration_status == "assigning": + self_weakref = weakref.ref(self) + + def copy_result_to_original_op(op, error): + logger.debug( + "Copying registration result from Query Status Op to Registration Op" + ) + initial_register_op.registration_result = op.registration_result + initial_register_op.error = error + + @pipeline_thread.invoke_on_pipeline_thread_nowait + def do_query_after_interval(): + this = self_weakref() + initial_register_op.polling_timer.cancel() + initial_register_op.polling_timer = None + + logger.info( + "{stage_name}({op_name}): polling".format( + stage_name=this.name, op_name=op.name + ) + ) + + query_worker_op = initial_register_op.spawn_worker_op( + worker_op_type=pipeline_ops_provisioning.PollStatusOperation, + request_payload=" ", + operation_id=operation_id, + callback=copy_result_to_original_op, + ) + + self.send_op_down(query_worker_op) + + logger.warning( + "{stage_name}({op_name}): Op will transition into polling after interval {interval}. Setting timer.".format( + stage_name=self.name, + op_name=op.name, + interval=constant.DEFAULT_POLLING_INTERVAL, + ) + ) + + logger.debug( + "{}({}): Creating polling timer".format(self.name, op.name) + ) + initial_register_op.polling_timer = Timer( + constant.DEFAULT_POLLING_INTERVAL, do_query_after_interval + ) + initial_register_op.polling_timer.start() + + elif registration_status == "failed" or registration_status == "assigned": + self._process_failed_and_assigned_registration_status( + error=error, + operation_id=operation_id, + decoded_response=decoded_response, + registration_status=registration_status, + original_provisioning_op=initial_register_op, + request_response_op=op, + ) + + else: + self._process_unknown_registration_status( + registration_status=registration_status, + original_provisioning_op=initial_register_op, + request_response_op=op, + ) + + registration_payload = DeviceRegistrationPayload( + registration_id=initial_register_op.registration_id, + custom_payload=initial_register_op.request_payload, + ) + self.send_op_down( + pipeline_ops_base.RequestAndResponseOperation( + request_type=constant.REGISTER, + method="PUT", + resource_location="/", + request_body=registration_payload.get_json_string(), + callback=on_registration_response, + ) + ) + + else: + super(RegistrationStage, self)._run_op(op) + + +class DeviceRegistrationPayload(object): + """ + The class representing the payload that needs to be sent to the service. + """ + + def __init__(self, registration_id, custom_payload=None): + # This is not a convention to name variables in python but the + # DPS service spec needs the name to be exact for it to work + self.registrationId = registration_id + self.payload = custom_payload + + def get_json_string(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True) diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning_mqtt.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning_mqtt.py index 35a280cb3..51c032950 100644 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning_mqtt.py +++ b/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning_mqtt.py @@ -10,32 +10,31 @@ from azure.iot.device.common.pipeline import ( pipeline_ops_base, pipeline_ops_mqtt, pipeline_events_mqtt, - operation_flow, pipeline_thread, + pipeline_events_base, ) from azure.iot.device.common.pipeline.pipeline_stages_base import PipelineStage from azure.iot.device.provisioning.pipeline import mqtt_topic -from azure.iot.device.provisioning.pipeline import ( - pipeline_events_provisioning, - pipeline_ops_provisioning, -) +from azure.iot.device.provisioning.pipeline import pipeline_ops_provisioning from azure.iot.device import constant as pkg_constant +from . import constant as pipeline_constant +from azure.iot.device.product_info import ProductInfo logger = logging.getLogger(__name__) -class ProvisioningMQTTConverterStage(PipelineStage): +class ProvisioningMQTTTranslationStage(PipelineStage): """ PipelineStage which converts other Provisioning pipeline operations into MQTT operations. This stage also converts MQTT pipeline events into Provisioning pipeline events. """ def __init__(self): - super(ProvisioningMQTTConverterStage, self).__init__() + super(ProvisioningMQTTTranslationStage, self).__init__() self.action_to_topic = {} @pipeline_thread.runs_on_pipeline_thread - def _execute_op(self, op): + def _run_op(self, op): if isinstance(op, pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation): # get security client args from above, save some, use some to build topic names, @@ -44,7 +43,7 @@ class ProvisioningMQTTConverterStage(PipelineStage): client_id = op.registration_id query_param_seq = [ ("api-version", pkg_constant.PROVISIONING_API_VERSION), - ("ClientVersion", pkg_constant.USER_AGENT), + ("ClientVersion", ProductInfo.get_provisioning_user_agent()), ] username = "{id_scope}/registrations/{registration_id}/{query_params}".format( id_scope=op.id_scope, @@ -54,61 +53,59 @@ class ProvisioningMQTTConverterStage(PipelineStage): hostname = op.provisioning_host - operation_flow.delegate_to_different_op( - stage=self, - original_op=op, - new_op=pipeline_ops_mqtt.SetMQTTConnectionArgsOperation( - client_id=client_id, - hostname=hostname, - username=username, - client_cert=op.client_cert, - sas_token=op.sas_token, - ), + worker_op = op.spawn_worker_op( + worker_op_type=pipeline_ops_mqtt.SetMQTTConnectionArgsOperation, + client_id=client_id, + hostname=hostname, + username=username, + client_cert=op.client_cert, + sas_token=op.sas_token, ) + self.send_op_down(worker_op) - elif isinstance(op, pipeline_ops_provisioning.SendRegistrationRequestOperation): - # Convert Sending the request into MQTT Publish operations - topic = mqtt_topic.get_topic_for_register(op.request_id) - operation_flow.delegate_to_different_op( - stage=self, - original_op=op, - new_op=pipeline_ops_mqtt.MQTTPublishOperation( - topic=topic, payload=op.request_payload - ), - ) - - elif isinstance(op, pipeline_ops_provisioning.SendQueryRequestOperation): - # Convert Sending the request into MQTT Publish operations - topic = mqtt_topic.get_topic_for_query(op.request_id, op.operation_id) - operation_flow.delegate_to_different_op( - stage=self, - original_op=op, - new_op=pipeline_ops_mqtt.MQTTPublishOperation( - topic=topic, payload=op.request_payload - ), - ) + elif isinstance(op, pipeline_ops_base.RequestOperation): + if op.request_type == pipeline_constant.REGISTER: + topic = mqtt_topic.get_topic_for_register( + method=op.method, request_id=op.request_id + ) + worker_op = op.spawn_worker_op( + worker_op_type=pipeline_ops_mqtt.MQTTPublishOperation, + topic=topic, + payload=op.request_body, + ) + self.send_op_down(worker_op) + else: + topic = mqtt_topic.get_topic_for_query( + method=op.method, + request_id=op.request_id, + operation_id=op.query_params["operation_id"], + ) + worker_op = op.spawn_worker_op( + worker_op_type=pipeline_ops_mqtt.MQTTPublishOperation, + topic=topic, + payload=op.request_body, + ) + self.send_op_down(worker_op) elif isinstance(op, pipeline_ops_base.EnableFeatureOperation): # Enabling for register gets translated into an MQTT subscribe operation topic = mqtt_topic.get_topic_for_subscribe() - operation_flow.delegate_to_different_op( - stage=self, - original_op=op, - new_op=pipeline_ops_mqtt.MQTTSubscribeOperation(topic=topic), + worker_op = op.spawn_worker_op( + worker_op_type=pipeline_ops_mqtt.MQTTSubscribeOperation, topic=topic ) + self.send_op_down(worker_op) elif isinstance(op, pipeline_ops_base.DisableFeatureOperation): # Disabling a register response gets turned into an MQTT unsubscribe operation topic = mqtt_topic.get_topic_for_subscribe() - operation_flow.delegate_to_different_op( - stage=self, - original_op=op, - new_op=pipeline_ops_mqtt.MQTTUnsubscribeOperation(topic=topic), + worker_op = op.spawn_worker_op( + worker_op_type=pipeline_ops_mqtt.MQTTUnsubscribeOperation, topic=topic ) + self.send_op_down(worker_op) else: # All other operations get passed down - operation_flow.pass_op_to_next_stage(self, op) + super(ProvisioningMQTTTranslationStage, self)._run_op(op) @pipeline_thread.runs_on_pipeline_thread def _handle_pipeline_event(self, event): @@ -126,22 +123,22 @@ class ProvisioningMQTTConverterStage(PipelineStage): ) ) key_values = mqtt_topic.extract_properties_from_topic(topic) + retry_after = mqtt_topic.get_optional_element(key_values, "retry-after", 0) status_code = mqtt_topic.extract_status_code_from_topic(topic) request_id = key_values["rid"][0] - if event.payload is not None: - response = event.payload.decode("utf-8") - # Extract pertinent information from mqtt topic - # like status code request_id and send it upwards. - operation_flow.pass_event_to_previous_stage( - self, - pipeline_events_provisioning.RegistrationResponseEvent( - request_id, status_code, key_values, response - ), + + self.send_event_up( + pipeline_events_base.ResponseEvent( + request_id=request_id, + status_code=int(status_code, 10), + response_body=event.payload, + retry_after=retry_after, + ) ) else: logger.warning("Unknown topic: {} passing up to next handler".format(topic)) - operation_flow.pass_event_to_previous_stage(self, event) + self.send_event_up(event) else: # all other messages get passed up - operation_flow.pass_event_to_previous_stage(self, event) + super(ProvisioningMQTTTranslationStage, self)._handle_pipeline_event(event) diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/provisioning_pipeline.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/provisioning_pipeline.py index b35e4042a..8c5eb51b4 100644 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/provisioning_pipeline.py +++ b/azure-iot-device/azure/iot/device/provisioning/pipeline/provisioning_pipeline.py @@ -13,47 +13,95 @@ from azure.iot.device.provisioning.pipeline import ( pipeline_stages_provisioning, pipeline_stages_provisioning_mqtt, ) -from azure.iot.device.provisioning.pipeline import pipeline_events_provisioning from azure.iot.device.provisioning.pipeline import pipeline_ops_provisioning from azure.iot.device.provisioning.security import SymmetricKeySecurityClient, X509SecurityClient +from azure.iot.device.provisioning.pipeline import constant as provisioning_constants logger = logging.getLogger(__name__) class ProvisioningPipeline(object): - def __init__(self, security_client): + def __init__(self, security_client, pipeline_configuration): """ Constructor for instantiating a pipeline :param security_client: The security client which stores credentials """ + self.responses_enabled = {provisioning_constants.REGISTER: False} + # Event Handlers - Will be set by Client after instantiation of pipeline self.on_connected = None self.on_disconnected = None self.on_message_received = None + self._registration_id = security_client.registration_id self._pipeline = ( - pipeline_stages_base.PipelineRootStage() + # + # The root is always the root. By definition, it's the first stage in the pipeline. + # + pipeline_stages_base.PipelineRootStage(pipeline_configuration=pipeline_configuration) + # + # UseSecurityClientStager comes near the root by default because it doesn't need to be after + # anything, but it does need to be before ProvisoningMQTTTranslationStage. + # .append_stage(pipeline_stages_provisioning.UseSecurityClientStage()) - .append_stage(pipeline_stages_provisioning_mqtt.ProvisioningMQTTConverterStage()) - .append_stage(pipeline_stages_base.EnsureConnectionStage()) - .append_stage(pipeline_stages_base.SerializeConnectOpsStage()) + # + # RegistrationStage needs to come early because this is the stage that converts registration + # or query requests into request and response objects which are used by later stages + # + .append_stage(pipeline_stages_provisioning.RegistrationStage()) + # + # PollingStatusStage needs to come after RegistrationStage because RegistrationStage counts + # on PollingStatusStage to poll until the registration is complete. + # + .append_stage(pipeline_stages_provisioning.PollingStatusStage()) + # + # CoordinateRequestAndResponseStage needs to be after RegistrationStage and PollingStatusStage + # because these 2 stages create the request ops that CoordinateRequestAndResponseStage + # is coordinating. It needs to be before ProvisioningMQTTTranslationStage because that stage + # operates on ops that CoordinateRequestAndResponseStage produces + # + .append_stage(pipeline_stages_base.CoordinateRequestAndResponseStage()) + # + # ProvisioningMQTTTranslationStage comes here because this is the point where we can translate + # all operations directly into MQTT. After this stage, only pipeline_stages_base stages + # are allowed because ProvisioningMQTTTranslationStage removes all the provisioning-ness from the ops + # + .append_stage(pipeline_stages_provisioning_mqtt.ProvisioningMQTTTranslationStage()) + # + # AutoConnectStage comes here because only MQTT ops have the need_connection flag set + # and this is the first place in the pipeline wherer we can guaranetee that all network + # ops are MQTT ops. + # + .append_stage(pipeline_stages_base.AutoConnectStage()) + # + # ReconnectStage needs to be after AutoConnectStage because ReconnectStage sets/clears + # the virtually_conencted flag and we want an automatic connection op to set this flag so + # we can reconnect autoconnect operations. + # + .append_stage(pipeline_stages_base.ReconnectStage()) + # + # ConnectionLockStage needs to be after ReconnectStage because we want any ops that + # ReconnectStage creates to go through the ConnectionLockStage gate + # + .append_stage(pipeline_stages_base.ConnectionLockStage()) + # + # RetryStage needs to be near the end because it's retrying low-level MQTT operations. + # + .append_stage(pipeline_stages_base.RetryStage()) + # + # OpTimeoutStage needs to be after RetryStage because OpTimeoutStage returns the timeout + # errors that RetryStage is watching for. + # + .append_stage(pipeline_stages_base.OpTimeoutStage()) + # + # MQTTTransportStage needs to be at the very end of the pipeline because this is where + # operations turn into network traffic + # .append_stage(pipeline_stages_mqtt.MQTTTransportStage()) ) def _on_pipeline_event(event): - if isinstance(event, pipeline_events_provisioning.RegistrationResponseEvent): - if self.on_message_received: - self.on_message_received( - event.request_id, - event.status_code, - event.key_values, - event.response_payload, - ) - else: - logger.warning("Provisioning event received with no handler. dropping.") - - else: - logger.warning("Dropping unknown pipeline event {}".format(event.name)) + logger.warning("Dropping unknown pipeline event {}".format(event.name)) def _on_connected(): if self.on_connected: @@ -82,24 +130,25 @@ class ProvisioningPipeline(object): self._pipeline.run_op(op) callback.wait_for_completion() - if op.error: - logger.error("{} failed: {}".format(op.name, op.error)) - raise op.error def connect(self, callback=None): """ Connect to the service. :param callback: callback which is called when the connection to the service is complete. + + The following exceptions are not "raised", but rather returned via the "error" parameter + when invoking "callback": + + :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.ConnectionFailedError` + :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.ConnectionDroppedError` + :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.UnauthorizedError` + :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.ProtocolClientError` """ logger.info("connect called") - def pipeline_callback(call): - if call.error: - # TODO we need error semantics on the client - exit(1) - if callback: - callback() + def pipeline_callback(op, error): + callback(error=error) self._pipeline.run_op(pipeline_ops_base.ConnectOperation(callback=pipeline_callback)) @@ -108,83 +157,60 @@ class ProvisioningPipeline(object): Disconnect from the service. :param callback: callback which is called when the connection to the service has been disconnected + + The following exceptions are not "raised", but rather returned via the "error" parameter + when invoking "callback": + + :raises: :class:`azure.iot.device.iothub.pipeline.exceptions.ProtocolClientError` """ logger.info("disconnect called") - def pipeline_callback(call): - if call.error: - # TODO we need error semantics on the client - exit(1) - if callback: - callback() + def pipeline_callback(op, error): + callback(error=error) self._pipeline.run_op(pipeline_ops_base.DisconnectOperation(callback=pipeline_callback)) - def send_request(self, request_id, request_payload, operation_id=None, callback=None): - """ - Send a request to the Device Provisioning Service. - :param request_id: The id of the request - :param request_payload: The request which is to be sent. - :param operation_id: The id of the operation. - :param callback: callback which is called when the message publish has been acknowledged by the service. - """ - - def pipeline_callback(call): - if call.error: - # TODO we need error semantics on the client - exit(1) - if callback: - callback() - - op = None - if operation_id is not None: - op = pipeline_ops_provisioning.SendQueryRequestOperation( - request_id=request_id, - operation_id=operation_id, - request_payload=request_payload, - callback=pipeline_callback, - ) - else: - op = pipeline_ops_provisioning.SendRegistrationRequestOperation( - request_id=request_id, request_payload=request_payload, callback=pipeline_callback - ) - - self._pipeline.run_op(op) - def enable_responses(self, callback=None): """ - Disable response from the DPS service by subscribing to the appropriate topics. + Enable response from the DPS service by subscribing to the appropriate topics. - :param callback: callback which is called when the feature is enabled + :param callback: callback which is called when responses are enabled """ logger.debug("enable_responses called") - def pipeline_callback(call): - if call.error: - # TODO we need error semantics on the client - exit(1) - if callback: - callback() + self.responses_enabled[provisioning_constants.REGISTER] = True + + def pipeline_callback(op, error): + callback(error=error) self._pipeline.run_op( pipeline_ops_base.EnableFeatureOperation(feature_name=None, callback=pipeline_callback) ) - def disable_responses(self, callback=None): + def register(self, payload=None, callback=None): """ - Disable response from the DPS service by unsubscribing from the appropriate topics. - :param callback: callback which is called when the feature is disabled + Register to the device provisioning service. + :param payload: Payload that can be sent with the registration request. + :param callback: callback which is called when the registration is done. + The following exceptions are not "raised", but rather returned via the "error" parameter + when invoking "callback": + + :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.ConnectionFailedError` + :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.ConnectionDroppedError` + :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.UnauthorizedError` + :raises: :class:`azure.iot.device.provisioning.pipeline.exceptions.ProtocolClientError` """ - logger.debug("disable_responses called") - def pipeline_callback(call): - if call.error: - # TODO we need error semantics on the client - exit(1) - if callback: - callback() + def on_complete(op, error): + # TODO : Apparently when its failed we can get result as well as error. + if error: + callback(error=error, result=None) + else: + callback(result=op.registration_result) self._pipeline.run_op( - pipeline_ops_base.DisableFeatureOperation(feature_name=None, callback=pipeline_callback) + pipeline_ops_provisioning.RegisterOperation( + request_payload=payload, registration_id=self._registration_id, callback=on_complete + ) ) diff --git a/azure-iot-device/azure/iot/device/provisioning/provisioning_device_client.py b/azure-iot-device/azure/iot/device/provisioning/provisioning_device_client.py index 5881cac03..fc94bb9f2 100644 --- a/azure-iot-device/azure/iot/device/provisioning/provisioning_device_client.py +++ b/azure-iot-device/azure/iot/device/provisioning/provisioning_device_client.py @@ -4,63 +4,93 @@ # license information. # -------------------------------------------------------------------------- """ -This module contains one of the implementations of the Provisioning Device Client which uses Symmetric Key authentication. +This module contains user-facing synchronous Provisioning Device Client for Azure Provisioning +Device SDK. This client uses Symmetric Key and X509 authentication to register devices with an +IoT Hub via the Device Provisioning Service. """ import logging from azure.iot.device.common.evented_callback import EventedCallback from .abstract_provisioning_device_client import AbstractProvisioningDeviceClient from .abstract_provisioning_device_client import log_on_register_complete -from .internal.polling_machine import PollingMachine +from azure.iot.device.provisioning.pipeline import constant as dps_constant +from .pipeline import exceptions as pipeline_exceptions +from azure.iot.device import exceptions + logger = logging.getLogger(__name__) +def handle_result(callback): + try: + return callback.wait_for_completion() + except pipeline_exceptions.ConnectionDroppedError as e: + raise exceptions.ConnectionDroppedError(message="Lost connection to IoTHub", cause=e) + except pipeline_exceptions.ConnectionFailedError as e: + raise exceptions.ConnectionFailedError(message="Could not connect to IoTHub", cause=e) + except pipeline_exceptions.UnauthorizedError as e: + raise exceptions.CredentialError(message="Credentials invalid, could not connect", cause=e) + except pipeline_exceptions.ProtocolClientError as e: + raise exceptions.ClientError(message="Error in the IoTHub client", cause=e) + except Exception as e: + raise exceptions.ClientError(message="Unexpected failure", cause=e) + + class ProvisioningDeviceClient(AbstractProvisioningDeviceClient): """ Client which can be used to run the registration of a device with provisioning service - using Symmetric Key authentication. + using Symmetric Key orr X509 authentication. """ - def __init__(self, provisioning_pipeline): - """ - Initializer for the Provisioning Client. - NOTE : This initializer should not be called directly. - Instead, the class methods that start with `create_from_` should be used to create a client object. - :param provisioning_pipeline: The protocol pipeline for provisioning. As of now this only supports MQTT. - """ - super(ProvisioningDeviceClient, self).__init__(provisioning_pipeline) - self._polling_machine = PollingMachine(provisioning_pipeline) - def register(self): """ - Register the device with the with thw provisioning service - This is a synchronous call, meaning that this function will not return until the registration - process has completed successfully or the attempt has resulted in a failure. Before returning - the client will also disconnect from the provisioning service. - If a registration attempt is made while a previous registration is in progress it may throw an error. + Register the device with the with the provisioning service + + This is a synchronous call, meaning that this function will not return until the + registration process has completed successfully or the attempt has resulted in a failure. + Before returning, the client will also disconnect from the provisioning service. + If a registration attempt is made while a previous registration is in progress it may + throw an error. + + :returns: RegistrationResult indicating the result of the registration. + :rtype: :class:`azure.iot.device.RegistrationResult` + + :raises: :class:`azure.iot.device.exceptions.CredentialError` if credentials are invalid + and a connection cannot be established. + :raises: :class:`azure.iot.device.exceptions.ConnectionFailedError` if a establishing a + connection results in failure. + :raises: :class:`azure.iot.device.exceptions.ConnectionDroppedError` if connection is lost + during execution. + :raises: :class:`azure.iot.device.exceptions.ClientError` if there is an unexpected failure + during execution. """ logger.info("Registering with Provisioning Service...") + if not self._provisioning_pipeline.responses_enabled[dps_constant.REGISTER]: + self._enable_responses() + register_complete = EventedCallback(return_arg_name="result") - self._polling_machine.register(callback=register_complete) - result = register_complete.wait_for_completion() + + self._provisioning_pipeline.register( + payload=self._provisioning_payload, callback=register_complete + ) + + result = handle_result(register_complete) log_on_register_complete(result) return result - def cancel(self): + def _enable_responses(self): + """Enable to receive responses from Device Provisioning Service. + + This is a synchronous call, meaning that this function will not return until the feature + has been enabled. + """ - This is a synchronous call, meaning that this function will not return until the cancellation - process has completed successfully or the attempt has resulted in a failure. Before returning - the client will also disconnect from the provisioning service. + logger.info("Enabling reception of response from Device Provisioning Service...") - In case there is no registration in process it will throw an error as there is - no registration process to cancel. - """ - logger.info("Cancelling the current registration process") + subscription_complete = EventedCallback() + self._provisioning_pipeline.enable_responses(callback=subscription_complete) - cancel_complete = EventedCallback() - self._polling_machine.cancel(callback=cancel_complete) - cancel_complete.wait_for_completion() + handle_result(subscription_complete) - logger.info("Successfully cancelled the current registration process") + logger.info("Successfully subscribed to Device Provisioning Service to receive responses") diff --git a/azure-iot-device/doc/images/azure_iot_sdk_python_banner.png b/azure-iot-device/doc/images/azure_iot_sdk_python_banner.png new file mode 100644 index 000000000..384757e3b Binary files /dev/null and b/azure-iot-device/doc/images/azure_iot_sdk_python_banner.png differ diff --git a/azure-iot-device/samples/README.md b/azure-iot-device/samples/README.md index 686ebc1a2..fc0f29392 100644 --- a/azure-iot-device/samples/README.md +++ b/azure-iot-device/samples/README.md @@ -11,6 +11,7 @@ This directory contains samples showing how to use the various features of the M ```bash az iot hub create --resource-group --name ``` + * Note that this operation make take a few minutes. 2. Add the IoT Extension to the Azure CLI, and then [register a device identity](https://docs.microsoft.com/en-us/cli/azure/ext/azure-cli-iot-ext/iot/hub/device-identity?view=azure-cli-latest#ext-azure-cli-iot-ext-az-iot-hub-device-identity-create) @@ -20,14 +21,15 @@ This directory contains samples showing how to use the various features of the M az iot hub device-identity create --hub-name --device-id ``` -2. [Retrieve your Device Connection String](https://docs.microsoft.com/en-us/cli/azure/ext/azure-cli-iot-ext/iot/hub/device-identity?view=azure-cli-latest#ext-azure-cli-iot-ext-az-iot-hub-device-identity-show-connection-string) using the Azure CLI +3. [Retrieve your Device Connection String](https://docs.microsoft.com/en-us/cli/azure/ext/azure-cli-iot-ext/iot/hub/device-identity?view=azure-cli-latest#ext-azure-cli-iot-ext-az-iot-hub-device-identity-show-connection-string) using the Azure CLI ```bash az iot hub device-identity show-connection-string --device-id --hub-name ``` It should be in the format: - ``` + + ```Text HostName=.azure-devices.net;DeviceId=;SharedAccessKey= ``` @@ -39,13 +41,16 @@ This directory contains samples showing how to use the various features of the M 5. On your device, set the Device Connection String as an enviornment variable called `IOTHUB_DEVICE_CONNECTION_STRING`. - ### Windows (cmd) + **Windows (cmd)** + ```cmd set IOTHUB_DEVICE_CONNECTION_STRING= ``` + * Note that there are **NO** quotation marks around the connection string. - ### Linux (bash) + **Linux (bash)** + ```bash export IOTHUB_DEVICE_CONNECTION_STRING="" ``` @@ -56,7 +61,6 @@ This directory contains samples showing how to use the various features of the M import os import asyncio from azure.iot.device.aio import IoTHubDeviceClient - from azure.iot.device import auth async def main(): @@ -94,17 +98,18 @@ This directory contains samples showing how to use the various features of the M 8. Your device is now able to connect to Azure IoT Hub! ## Additional Samples + Further samples with more complex IoT Hub scenarios are contained in the [advanced-hub-scenarios](advanced-hub-scenarios) directory, including: * Send multiple telemetry messages from a Device * Receive Cloud-to-Device (C2D) messages on a Device * Send and receive updates to device twin -* Receive direct method invocations +* Receive direct method invocations -Further samples with more complex IoT Edge scnearios are contained in the [advanced-edge-scenarios](advanced-edge-scenarios) directory, including: +Further samples with more complex IoT Edge scenarios are contained in the [advanced-edge-scenarios](advanced-edge-scenarios) directory, including: * Send multiple telemetry messages from a Module * Receive input messages on a Module * Send messages to a Module Output -Samples for the legacy clients, that use a synchronous API, intended for use with Python 2.7, Python 3.4, or compatibility scenarios for Python 3.5+ are contained in the [legacy-samples](legacy-samples) directory. \ No newline at end of file +Samples for the synchronous clients, intended for use with Python 2.7 or compatibility scenarios for Python 3.5+ are contained in the [sync-samples](sync-samples) directory. diff --git a/azure-iot-device/samples/advanced-hub-scenarios/README.md b/azure-iot-device/samples/advanced-hub-scenarios/README.md deleted file mode 100644 index e4c17aaad..000000000 --- a/azure-iot-device/samples/advanced-hub-scenarios/README.md +++ /dev/null @@ -1,59 +0,0 @@ -# Advanced IoT Hub Scenario Samples for the Azure IoT Hub Device SDK - -This directory contains samples showing how to use the various features of Azure IoT Hub Device SDK with the Azure IoT Hub. - -**These samples are written to run in Python 3.7+**, but can be made to work with Python 3.5 and 3.6 with a slight modification as noted in each sample: - -```python -if __name__ == "__main__": - asyncio.run(main()) - - # If using Python 3.6 or below, use the following code instead of asyncio.run(main()): - # loop = asyncio.get_event_loop() - # loop.run_until_complete(main()) - # loop.close() -``` - -## Included Samples - -### IoTHub Samples -In order to use these samples, you **must** set your Device Connection String in the environment variable `IOTHUB_DEVICE_CONNECTION_STRING`. - -* [send_message.py](send_message.py) - Send multiple telmetry messages in parallel from a device to the Azure IoT Hub. - * You can monitor the Azure IoT Hub for messages received by using the following Azure CLI command: - ```bash - az iot hub monitor-events --hub-name --output table - ``` -* [receive_message.py](receive_message.py) - Receive Cloud-to-Device (C2D) messages sent from the Azure IoT Hub to a device. - * In order to send a C2D message, use the following Azure CLI command: - ``` - az iot device c2d-message send --device-id --hub-name --data - ``` -* [receive_direct_method.py](receive_direct_method.py) - Receive direct method requests on a device from the Azure IoT Hub and send responses back - * In order to invoke a direct method, use the following Azure CLI command: - ``` - az iot hub invoke-device-method --device-id --hub-name --method-name - ``` -* [receive_twin_desired_properties_patch](receive_twin_desired_properties_patch.py) - Receive an update patch of changes made to the device twin's desired properties - * In order to send a update patch to a device twin's reported properties, use the following Azure CLI command: - ``` - az iot hub device-twin update --device-id --hub-name --set properties.desired.= - ``` -* [update_twin_reported_properties](update_twin_reported_properties.py) - Send an update patch of changes to the device twin's reported properties - * You can see the changes reflected in your device twin by using the following Azure CLI command: - ``` - az iot hub device-twin show --device-id --hub-name - ``` - - -### DPS Samples -In order to use these samples, you **must** have the following environment variables :- - -* PROVISIONING_HOST -* PROVISIONING_IDSCOPE -* PROVISIONING_REGISTRATION_ID - -There are 2 ways that your device can get registered to the provisioning service differing in authentication mechanisms and another additional environment variable is needed to for the samples:- - -* [register_symmetric_key.py](register_symmetric_key.py) - Register to provisioning service using a symmetric key. For this you must have the environment variable PROVISIONING_SYMMETRIC_KEY. -* [register_x509.py](register_x509.py) - Register to provisioning service using a symmetric key. For this you must have the environment variable X509_CERT_FILE, X509_KEY_FILE, PASS_PHRASE. \ No newline at end of file diff --git a/azure-iot-device/samples/advanced-edge-scenarios/README.md b/azure-iot-device/samples/async-edge-scenarios/README.md similarity index 100% rename from azure-iot-device/samples/advanced-edge-scenarios/README.md rename to azure-iot-device/samples/async-edge-scenarios/README.md diff --git a/azure-iot-device/samples/async-edge-scenarios/invoke_method_on_module.py b/azure-iot-device/samples/async-edge-scenarios/invoke_method_on_module.py new file mode 100644 index 000000000..2a50466cf --- /dev/null +++ b/azure-iot-device/samples/async-edge-scenarios/invoke_method_on_module.py @@ -0,0 +1,43 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +import asyncio +import time +import uuid +from azure.iot.device.aio import IoTHubModuleClient +from azure.iot.device import Message + +messages_to_send = 10 + + +async def main(): + # Inputs/Ouputs are only supported in the context of Azure IoT Edge and module client + # The module client object acts as an Azure IoT Edge module and interacts with an Azure IoT Edge hub + module_client = IoTHubModuleClient.create_from_edge_environment() + + # Connect the client. + await module_client.connect() + fake_method_params = { + "methodName": "doSomethingInteresting", + "payload": "foo", + "responseTimeoutInSeconds": 5, + "connectTimeoutInSeconds": 2, + } + response = await module_client.invoke_method( + device_id="fakeDeviceId", module_id="fakeModuleId", method_params=fake_method_params + ) + print("Method Response: {}".format(response)) + # finally, disconnect + module_client.disconnect() + + +if __name__ == "__main__": + asyncio.run(main()) + + # If using Python 3.6 or below, use the following code instead of asyncio.run(main()): + # loop = asyncio.get_event_loop() + # loop.run_until_complete(main()) + # loop.close() diff --git a/azure-iot-device/samples/advanced-edge-scenarios/receive_message_on_input.py b/azure-iot-device/samples/async-edge-scenarios/receive_message_on_input.py similarity index 100% rename from azure-iot-device/samples/advanced-edge-scenarios/receive_message_on_input.py rename to azure-iot-device/samples/async-edge-scenarios/receive_message_on_input.py diff --git a/azure-iot-device/samples/advanced-edge-scenarios/send_message.py b/azure-iot-device/samples/async-edge-scenarios/send_message.py similarity index 100% rename from azure-iot-device/samples/advanced-edge-scenarios/send_message.py rename to azure-iot-device/samples/async-edge-scenarios/send_message.py diff --git a/azure-iot-device/samples/advanced-edge-scenarios/send_message_to_output.py b/azure-iot-device/samples/async-edge-scenarios/send_message_to_output.py similarity index 100% rename from azure-iot-device/samples/advanced-edge-scenarios/send_message_to_output.py rename to azure-iot-device/samples/async-edge-scenarios/send_message_to_output.py diff --git a/azure-iot-device/samples/async-hub-scenarios/README.md b/azure-iot-device/samples/async-hub-scenarios/README.md new file mode 100644 index 000000000..dd3c98cc3 --- /dev/null +++ b/azure-iot-device/samples/async-hub-scenarios/README.md @@ -0,0 +1,83 @@ +# Advanced IoT Hub Scenario Samples for the Azure IoT Hub Device SDK + +This directory contains samples showing how to use the various features of Azure IoT Hub Device SDK with the Azure IoT Hub. + +**These samples are written to run in Python 3.7+**, but can be made to work with Python 3.5 and 3.6 with a slight modification as noted in each sample: + +```python +if __name__ == "__main__": + asyncio.run(main()) + + # If using Python 3.6 or below, use the following code instead of asyncio.run(main()): + # loop = asyncio.get_event_loop() + # loop.run_until_complete(main()) + # loop.close() +``` + +## Included Samples + +### IoTHub Samples + +In order to use these samples, you **must** set your Device Connection String in the environment variable `IOTHUB_DEVICE_CONNECTION_STRING`. + +* [send_message.py](send_message.py) - Send multiple telmetry messages in parallel from a device to the Azure IoT Hub. + * You can monitor the Azure IoT Hub for messages received by using the following Azure CLI command: + + ```bash + az iot hub monitor-events --hub-name --output table + ``` + +* [receive_message.py](receive_message.py) - Receive Cloud-to-Device (C2D) messages sent from the Azure IoT Hub to a device. + * In order to send a C2D message, use the following Azure CLI command: + + ```bash + az iot device c2d-message send --device-id --hub-name --data + ``` + +* [receive_direct_method.py](receive_direct_method.py) - Receive direct method requests on a device from the Azure IoT Hub and send responses back + * In order to invoke a direct method, use the following Azure CLI command: + + ```bash + az iot hub invoke-device-method --device-id --hub-name --method-name + ``` + +* [receive_twin_desired_properties_patch](receive_twin_desired_properties_patch.py) - Receive an update patch of changes made to the device twin's desired properties + * In order to send a update patch to a device twin's reported properties, use the following Azure CLI command: + + ```bash + az iot hub device-twin update --device-id --hub-name --set properties.desired.= + ``` + +* [update_twin_reported_properties](update_twin_reported_properties.py) - Send an update patch of changes to the device twin's reported properties + * You can see the changes reflected in your device twin by using the following Azure CLI command: + + ```bash + az iot hub device-twin show --device-id --hub-name + ``` + +### DPS Samples + +#### Individual + +In order to use these samples, you **must** have the following environment variables :- + +* PROVISIONING_HOST +* PROVISIONING_IDSCOPE +* PROVISIONING_REGISTRATION_ID + +There are 2 ways that your device can get registered to the provisioning service differing in authentication mechanisms and another additional environment variable is needed to for the samples:- + +* [provision_symmetric_key.py](provision_symmetric_key.py) - Provision a device to IoTHub by registering to the Device Provisioning Service using a symmetric key. For this you must have the environment variable PROVISIONING_SYMMETRIC_KEY. +* [provision_symmetric_key_and_send_telemetry.py](provision_symmetric_key_and_send_telemetry.py) - Provision a device to IoTHub by registering to the Device Provisioning Service using a symmetric key, then send a telemetry message to IoTHub. For this you must have the environment variable PROVISIONING_SYMMETRIC_KEY. +* [provision_symmetric_key_with_payload.py](provision_symmetric_key_with_payload.py) - Provision a device to IoTHub by registering to the Device Provisioning Service using a symmetric key while supplying a custom payload. For this you must have the environment variable PROVISIONING_SYMMETRIC_KEY. +* [provision_x509.py](provision_x509.py) - Provision a device to IoTHub by registering to the Device Provisioning Service using a symmetric key. For this you must have the environment variable X509_CERT_FILE, X509_KEY_FILE, PASS_PHRASE. +* [provision_x509_and_send_telemetry.py](provision_x509_and_send_telemetry.py) - Provision a device to IoTHub by registering to the Device Provisioning Service using a symmetric key, then send a telemetry message to IoTHub. For this you must have the environment variable X509_CERT_FILE, X509_KEY_FILE, PASS_PHRASE. + +#### Group + +In order to use these samples, you **must** have the following environment variables :- + +* PROVISIONING_HOST +* PROVISIONING_IDSCOPE + +* [provision_symmetric_key_group.py](provision_symmetric_key_group.py) - Provision multiple devices to IoTHub by registering them to the Device Provisioning Service using derived symmetric keys. For this you must have the environment variables PROVISIONING_MASTER_SYMMETRIC_KEY, PROVISIONING_DEVICE_ID_1, PROVISIONING_DEVICE_ID_2, PROVISIONING_DEVICE_ID_3. \ No newline at end of file diff --git a/azure-iot-device/samples/advanced-hub-scenarios/get_twin.py b/azure-iot-device/samples/async-hub-scenarios/get_twin.py similarity index 100% rename from azure-iot-device/samples/advanced-hub-scenarios/get_twin.py rename to azure-iot-device/samples/async-hub-scenarios/get_twin.py diff --git a/azure-iot-device/samples/advanced-hub-scenarios/register_symmetric_key.py b/azure-iot-device/samples/async-hub-scenarios/provision_symmetric_key.py similarity index 68% rename from azure-iot-device/samples/advanced-hub-scenarios/register_symmetric_key.py rename to azure-iot-device/samples/async-hub-scenarios/provision_symmetric_key.py index 730891706..f678a3cdb 100644 --- a/azure-iot-device/samples/advanced-hub-scenarios/register_symmetric_key.py +++ b/azure-iot-device/samples/async-hub-scenarios/provision_symmetric_key.py @@ -15,18 +15,15 @@ symmetric_key = os.getenv("PROVISIONING_SYMMETRIC_KEY") async def main(): - async def register_device(): - provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host=provisioning_host, - registration_id=registration_id, - id_scope=id_scope, - symmetric_key=symmetric_key, - ) + provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( + provisioning_host=provisioning_host, + registration_id=registration_id, + id_scope=id_scope, + symmetric_key=symmetric_key, + ) - return await provisioning_device_client.register() + registration_result = await provisioning_device_client.register() - results = await asyncio.gather(register_device()) - registration_result = results[0] print("The complete registration result is") print(registration_result.registration_state) diff --git a/azure-iot-device/samples/async-hub-scenarios/provision_symmetric_key_and_send_telemetry.py b/azure-iot-device/samples/async-hub-scenarios/provision_symmetric_key_and_send_telemetry.py new file mode 100644 index 000000000..a343b3206 --- /dev/null +++ b/azure-iot-device/samples/async-hub-scenarios/provision_symmetric_key_and_send_telemetry.py @@ -0,0 +1,70 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import asyncio +from azure.iot.device.aio import ProvisioningDeviceClient +import os +from azure.iot.device.aio import IoTHubDeviceClient +from azure.iot.device import Message +import uuid + + +messages_to_send = 10 +provisioning_host = os.getenv("PROVISIONING_HOST") +id_scope = os.getenv("PROVISIONING_IDSCOPE") +registration_id = os.getenv("PROVISIONING_REGISTRATION_ID") +symmetric_key = os.getenv("PROVISIONING_SYMMETRIC_KEY") + + +async def main(): + provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( + provisioning_host=provisioning_host, + registration_id=registration_id, + id_scope=id_scope, + symmetric_key=symmetric_key, + ) + + registration_result = await provisioning_device_client.register() + + print("The complete registration result is") + print(registration_result.registration_state) + + if registration_result.status == "assigned": + print("Will send telemetry from the provisioned device") + device_client = IoTHubDeviceClient.create_from_symmetric_key( + symmetric_key=symmetric_key, + hostname=registration_result.registration_state.assigned_hub, + device_id=registration_result.registration_state.device_id, + ) + # Connect the client. + await device_client.connect() + + async def send_test_message(i): + print("sending message #" + str(i)) + msg = Message("test wind speed " + str(i)) + msg.message_id = uuid.uuid4() + msg.correlation_id = "correlation-1234" + msg.custom_properties["count"] = i + msg.custom_properties["tornado-warning"] = "yes" + await device_client.send_message(msg) + print("done sending message #" + str(i)) + + # send `messages_to_send` messages in parallel + await asyncio.gather(*[send_test_message(i) for i in range(1, messages_to_send + 1)]) + + # finally, disconnect + await device_client.disconnect() + else: + print("Can not send telemetry from the provisioned device") + + +if __name__ == "__main__": + asyncio.run(main()) + + # If using Python 3.6 or below, use the following code instead of asyncio.run(main()): + # loop = asyncio.get_event_loop() + # loop.run_until_complete(main()) + # loop.close() diff --git a/azure-iot-device/samples/async-hub-scenarios/provision_symmetric_key_group.py b/azure-iot-device/samples/async-hub-scenarios/provision_symmetric_key_group.py new file mode 100644 index 000000000..4c8226fba --- /dev/null +++ b/azure-iot-device/samples/async-hub-scenarios/provision_symmetric_key_group.py @@ -0,0 +1,87 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import asyncio +import base64 +import hmac +import hashlib +from azure.iot.device.aio import ProvisioningDeviceClient + +provisioning_host = os.getenv("PROVISIONING_HOST") +id_scope = os.getenv("PROVISIONING_IDSCOPE") + +# These are the names of the devices that will eventually show up on the IoTHub +device_id_1 = os.getenv("PROVISIONING_DEVICE_ID_1") +device_id_2 = os.getenv("PROVISIONING_DEVICE_ID_2") +device_id_3 = os.getenv("PROVISIONING_DEVICE_ID_3") + +# For computation of device keys +device_ids_to_keys = {} + + +# NOTE : Only for illustration purposes. +# This is how a device key can be derived from the group symmetric key. +# This is just a helper function to show how it is done. +# Please don't directly store the group key on the device. +# Follow the following method to compute the device key somewhere else. + + +def derive_device_key(device_id, group_symmetric_key): + """ + The unique device ID and the group master key should be encoded into "utf-8" + After this the encoded group master key must be used to compute an HMAC-SHA256 of the encoded registration ID. + Finally the result must be converted into Base64 format. + The device key is the "utf-8" decoding of the above result. + """ + message = device_id.encode("utf-8") + signing_key = base64.b64decode(group_symmetric_key.encode("utf-8")) + signed_hmac = hmac.HMAC(signing_key, message, hashlib.sha256) + device_key_encoded = base64.b64encode(signed_hmac.digest()) + return device_key_encoded.decode("utf-8") + + +# derived_device_key has been computed already using the helper function somewhere else +# AND NOT on this sample. Do not use the direct master key on this sample to compute device key. +derived_device_key_1 = "some_value_already_computed" +derived_device_key_2 = "some_value_already_computed" +derived_device_key_3 = "some_value_already_computed" + + +device_ids_to_keys[device_id_1] = derived_device_key_1 +device_ids_to_keys[device_id_1] = derived_device_key_2 +device_ids_to_keys[device_id_1] = derived_device_key_3 + + +async def main(): + async def register_device(registration_id): + provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( + provisioning_host=provisioning_host, + registration_id=registration_id, + id_scope=id_scope, + symmetric_key=device_ids_to_keys[registration_id], + ) + + return await provisioning_device_client.register() + + results = await asyncio.gather( + register_device(device_ids_to_keys[device_id_1]), + register_device(device_ids_to_keys[device_id_2]), + register_device(device_ids_to_keys[device_id_3]), + ) + for index in range(0, len(device_ids_to_keys)): + registration_result = results[index] + print("The complete state of registration result is") + print(registration_result.registration_state) + + +if __name__ == "__main__": + asyncio.run(main()) + + # If using Python 3.6 or below, use the following code instead of asyncio.run(main()): + # loop = asyncio.get_event_loop() + # loop.run_until_complete(main()) + # loop.close() diff --git a/azure-iot-device/samples/async-hub-scenarios/provision_symmetric_key_with_payload.py b/azure-iot-device/samples/async-hub-scenarios/provision_symmetric_key_with_payload.py new file mode 100644 index 000000000..d9290f3b2 --- /dev/null +++ b/azure-iot-device/samples/async-hub-scenarios/provision_symmetric_key_with_payload.py @@ -0,0 +1,47 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import asyncio +from azure.iot.device.aio import ProvisioningDeviceClient + +provisioning_host = os.getenv("PROVISIONING_HOST") +id_scope = os.getenv("PROVISIONING_IDSCOPE") +registration_id = os.getenv("PROVISIONING_REGISTRATION_ID_PAYLOAD") +symmetric_key = os.getenv("PROVISIONING_SYMMETRIC_KEY_PAYLOAD") + + +class Wizard(object): + def __init__(self, first_name, last_name, dict_of_stuff): + self.first_name = first_name + self.last_name = last_name + self.props = dict_of_stuff + + +async def main(): + provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( + provisioning_host=provisioning_host, + registration_id=registration_id, + id_scope=id_scope, + symmetric_key=symmetric_key, + ) + + properties = {"House": "Gryffindor", "Muggle-Born": "False"} + wizard_a = Wizard("Harry", "Potter", properties) + provisioning_device_client.provisioning_payload = wizard_a + registration_result = await provisioning_device_client.register() + + print("The complete registration result is") + print(registration_result.registration_state) + + +if __name__ == "__main__": + asyncio.run(main()) + + # If using Python 3.6 or below, use the following code instead of asyncio.run(main()): + # loop = asyncio.get_event_loop() + # loop.run_until_complete(main()) + # loop.close() diff --git a/azure-iot-device/samples/advanced-hub-scenarios/register_x509.py b/azure-iot-device/samples/async-hub-scenarios/provision_x509.py similarity index 57% rename from azure-iot-device/samples/advanced-hub-scenarios/register_x509.py rename to azure-iot-device/samples/async-hub-scenarios/provision_x509.py index ef76f1025..34b58608c 100644 --- a/azure-iot-device/samples/advanced-hub-scenarios/register_x509.py +++ b/azure-iot-device/samples/async-hub-scenarios/provision_x509.py @@ -3,10 +3,6 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- - - -# This is for illustration purposes only. The sample will not work currently. - import os import asyncio from azure.iot.device import X509 @@ -18,23 +14,20 @@ registration_id = os.getenv("DPS_X509_REGISTRATION_ID") async def main(): - async def register_device(): - x509 = X509( - cert_file=os.getenv("X509_CERT_FILE"), - key_file=os.getenv("X509_KEY_FILE"), - pass_phrase=os.getenv("PASS_PHRASE"), - ) - provisioning_device_client = ProvisioningDeviceClient.create_from_x509_certificate( - provisioning_host=provisioning_host, - registration_id=registration_id, - id_scope=id_scope, - x509=x509, - ) + x509 = X509( + cert_file=os.getenv("X509_CERT_FILE"), + key_file=os.getenv("X509_KEY_FILE"), + pass_phrase=os.getenv("PASS_PHRASE"), + ) + provisioning_device_client = ProvisioningDeviceClient.create_from_x509_certificate( + provisioning_host=provisioning_host, + registration_id=registration_id, + id_scope=id_scope, + x509=x509, + ) - return await provisioning_device_client.register() + registration_result = await provisioning_device_client.register() - results = await asyncio.gather(register_device()) - registration_result = results[0] print("The complete registration result is") print(registration_result.registration_state) diff --git a/azure-iot-device/samples/async-hub-scenarios/provision_x509_and_send_telemetry.py b/azure-iot-device/samples/async-hub-scenarios/provision_x509_and_send_telemetry.py new file mode 100644 index 000000000..6ec9d3ddf --- /dev/null +++ b/azure-iot-device/samples/async-hub-scenarios/provision_x509_and_send_telemetry.py @@ -0,0 +1,76 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import os +import asyncio +from azure.iot.device import X509 +from azure.iot.device.aio import ProvisioningDeviceClient +from azure.iot.device.aio import IoTHubDeviceClient +from azure.iot.device import Message +import uuid + + +provisioning_host = os.getenv("PROVISIONING_HOST") +id_scope = os.getenv("PROVISIONING_IDSCOPE") +registration_id = os.getenv("DPS_X509_REGISTRATION_ID") +messages_to_send = 10 + + +async def main(): + x509 = X509( + cert_file=os.getenv("X509_CERT_FILE"), + key_file=os.getenv("X509_KEY_FILE"), + pass_phrase=os.getenv("PASS_PHRASE"), + ) + + provisioning_device_client = ProvisioningDeviceClient.create_from_x509_certificate( + provisioning_host=provisioning_host, + registration_id=registration_id, + id_scope=id_scope, + x509=x509, + ) + + registration_result = await provisioning_device_client.register() + + print("The complete registration result is") + print(registration_result.registration_state) + + if registration_result.status == "assigned": + print("Will send telemetry from the provisioned device") + device_client = IoTHubDeviceClient.create_from_x509_certificate( + x509=x509, + hostname=registration_result.registration_state.assigned_hub, + device_id=registration_result.registration_state.device_id, + ) + + # Connect the client. + await device_client.connect() + + async def send_test_message(i): + print("sending message #" + str(i)) + msg = Message("test wind speed " + str(i)) + msg.message_id = uuid.uuid4() + msg.correlation_id = "correlation-1234" + msg.custom_properties["count"] = i + msg.custom_properties["tornado-warning"] = "yes" + await device_client.send_message(msg) + print("done sending message #" + str(i)) + + # send `messages_to_send` messages in parallel + await asyncio.gather(*[send_test_message(i) for i in range(1, messages_to_send + 1)]) + + # finally, disconnect + await device_client.disconnect() + else: + print("Can not send telemetry from the provisioned device") + + +if __name__ == "__main__": + asyncio.run(main()) + + # If using Python 3.6 or below, use the following code instead of asyncio.run(main()): + # loop = asyncio.get_event_loop() + # loop.run_until_complete(main()) + # loop.close() diff --git a/azure-iot-device/samples/advanced-hub-scenarios/receive_direct_method.py b/azure-iot-device/samples/async-hub-scenarios/receive_direct_method.py similarity index 100% rename from azure-iot-device/samples/advanced-hub-scenarios/receive_direct_method.py rename to azure-iot-device/samples/async-hub-scenarios/receive_direct_method.py diff --git a/azure-iot-device/samples/advanced-hub-scenarios/receive_message.py b/azure-iot-device/samples/async-hub-scenarios/receive_message.py similarity index 100% rename from azure-iot-device/samples/advanced-hub-scenarios/receive_message.py rename to azure-iot-device/samples/async-hub-scenarios/receive_message.py diff --git a/azure-iot-device/samples/advanced-hub-scenarios/receive_message_x509.py b/azure-iot-device/samples/async-hub-scenarios/receive_message_x509.py similarity index 100% rename from azure-iot-device/samples/advanced-hub-scenarios/receive_message_x509.py rename to azure-iot-device/samples/async-hub-scenarios/receive_message_x509.py diff --git a/azure-iot-device/samples/advanced-hub-scenarios/receive_twin_desired_properties_patch.py b/azure-iot-device/samples/async-hub-scenarios/receive_twin_desired_properties_patch.py similarity index 100% rename from azure-iot-device/samples/advanced-hub-scenarios/receive_twin_desired_properties_patch.py rename to azure-iot-device/samples/async-hub-scenarios/receive_twin_desired_properties_patch.py diff --git a/azure-iot-device/samples/advanced-hub-scenarios/send_message.py b/azure-iot-device/samples/async-hub-scenarios/send_message.py similarity index 100% rename from azure-iot-device/samples/advanced-hub-scenarios/send_message.py rename to azure-iot-device/samples/async-hub-scenarios/send_message.py diff --git a/azure-iot-device/samples/async-hub-scenarios/send_message_over_websockets.py b/azure-iot-device/samples/async-hub-scenarios/send_message_over_websockets.py new file mode 100644 index 000000000..909174935 --- /dev/null +++ b/azure-iot-device/samples/async-hub-scenarios/send_message_over_websockets.py @@ -0,0 +1,35 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import asyncio +from azure.iot.device.aio import IoTHubDeviceClient + + +async def main(): + # Fetch the connection string from an enviornment variable + conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") + + # Create instance of the device client using the connection string + device_client = IoTHubDeviceClient.create_from_connection_string(conn_str, websockets=True) + + # We do not need to call device_client.connect(), since it will be connected when we send a message. + + # Send a single message + print("Sending message...") + await device_client.send_message("This is a message that is being sent") + print("Message successfully sent!") + + # Finally, we do not need a disconnect. When the program completes, the client will be disconnected and destroyed. + + +if __name__ == "__main__": + asyncio.run(main()) + + # If using Python 3.6 or below, use the following code instead of asyncio.run(main()): + # loop = asyncio.get_event_loop() + # loop.run_until_complete(main()) + # loop.close() diff --git a/azure-iot-device/samples/advanced-hub-scenarios/send_message_via_module_x509.py b/azure-iot-device/samples/async-hub-scenarios/send_message_via_module_x509.py similarity index 100% rename from azure-iot-device/samples/advanced-hub-scenarios/send_message_via_module_x509.py rename to azure-iot-device/samples/async-hub-scenarios/send_message_via_module_x509.py diff --git a/azure-iot-device/samples/async-hub-scenarios/send_message_via_proxy.py b/azure-iot-device/samples/async-hub-scenarios/send_message_via_proxy.py new file mode 100644 index 000000000..39dfdb25d --- /dev/null +++ b/azure-iot-device/samples/async-hub-scenarios/send_message_via_proxy.py @@ -0,0 +1,55 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import asyncio +import uuid +from azure.iot.device.aio import IoTHubDeviceClient +from azure.iot.device import Message, ProxyOptions +import socks + +messages_to_send = 10 + + +async def main(): + # The connection string for a device should never be stored in code. For the sake of simplicity we're using an environment variable here. + conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") + + proxy_opts = ProxyOptions( + proxy_type=socks.HTTP, proxy_addr="127.0.0.1", proxy_port=8888 # localhost + ) + + # The client object is used to interact with your Azure IoT hub. + device_client = IoTHubDeviceClient.create_from_connection_string( + conn_str, websockets=True, proxy_options=proxy_opts + ) + + # Connect the client. + await device_client.connect() + + async def send_test_message(i): + print("sending message #" + str(i)) + msg = Message("test wind speed " + str(i)) + msg.message_id = uuid.uuid4() + msg.correlation_id = "correlation-1234" + msg.custom_properties["tornado-warning"] = "yes" + await device_client.send_message(msg) + print("done sending message #" + str(i)) + + # send `messages_to_send` messages in parallel + await asyncio.gather(*[send_test_message(i) for i in range(1, messages_to_send + 1)]) + + # finally, disconnect + await device_client.disconnect() + + +if __name__ == "__main__": + asyncio.run(main()) + + # If using Python 3.6 or below, use the following code instead of asyncio.run(main()): + # loop = asyncio.get_event_loop() + # loop.run_until_complete(main()) + # loop.close() diff --git a/azure-iot-device/samples/advanced-hub-scenarios/send_message_x509.py b/azure-iot-device/samples/async-hub-scenarios/send_message_x509.py similarity index 100% rename from azure-iot-device/samples/advanced-hub-scenarios/send_message_x509.py rename to azure-iot-device/samples/async-hub-scenarios/send_message_x509.py diff --git a/azure-iot-device/samples/advanced-hub-scenarios/update_twin_reported_properties.py b/azure-iot-device/samples/async-hub-scenarios/update_twin_reported_properties.py similarity index 100% rename from azure-iot-device/samples/advanced-hub-scenarios/update_twin_reported_properties.py rename to azure-iot-device/samples/async-hub-scenarios/update_twin_reported_properties.py diff --git a/azure-iot-device/samples/async-hub-scenarios/upload_to_blob.py b/azure-iot-device/samples/async-hub-scenarios/upload_to_blob.py new file mode 100644 index 000000000..8b93aa26e --- /dev/null +++ b/azure-iot-device/samples/async-hub-scenarios/upload_to_blob.py @@ -0,0 +1,121 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import uuid +import asyncio +from azure.iot.device.aio import IoTHubDeviceClient, IoTHubModuleClient +from azure.iot.device import X509 +import http.client +import pprint +import json +from azure.storage.blob import BlobServiceClient, BlobClient, ContainerClient +import logging + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +""" +Welcome to the Upload to Blob sample. To use this sample you must have azure.storage.blob installed in your python environment. +To do this, you can run: + + $ pip isntall azure.storage.blob + +This sample covers using the following Device Client APIs: + + get_storage_info_for_blob + - used to get relevant information from IoT Hub about a linked Storage Account, including + a hostname, a container name, a blob name, and a sas token. Additionally it returns a correlation_id + which is used in the notify_blob_upload_status, since the correlation_id is IoT Hub's way of marking + which blob you are working on. + notify_blob_upload_status + - used to notify IoT Hub of the status of your blob storage operation. This uses the correlation_id obtained + by the get_storage_info_for_blob task, and will tell IoT Hub to notify any service that might be listening for a notification on the + status of the file upload task. + +You can learn more about File Upload with IoT Hub here: + +https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-file-upload + +""" + +# Host is in format ".azure-devices.net" + + +async def storage_blob(blob_info): + try: + print("Azure Blob storage v12 - Python quickstart sample") + sas_url = "https://{}/{}/{}{}".format( + blob_info["hostName"], + blob_info["containerName"], + blob_info["blobName"], + blob_info["sasToken"], + ) + blob_client = BlobClient.from_blob_url(sas_url) + # Create a file in local Documents directory to upload and download + local_file_name = "data/quickstart" + str(uuid.uuid4()) + ".txt" + filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), local_file_name) + # Write text to the file + if not os.path.exists(os.path.dirname(filename)): + os.makedirs(os.path.dirname(filename)) + file = open(filename, "w") + file.write("Hello, World!") + file.close() + + print("\nUploading to Azure Storage as blob:\n\t" + local_file_name) + # # Upload the created file + with open(filename, "rb") as f: + result = blob_client.upload_blob(f) + return (None, result) + + except Exception as ex: + print("Exception:") + print(ex) + return ex + + +async def main(): + hostname = os.getenv("IOTHUB_HOSTNAME") + device_id = os.getenv("IOTHUB_DEVICE_ID") + x509 = X509( + cert_file=os.getenv("X509_CERT_FILE"), + key_file=os.getenv("X509_KEY_FILE"), + pass_phrase=os.getenv("PASS_PHRASE"), + ) + + device_client = IoTHubDeviceClient.create_from_x509_certificate( + hostname=hostname, device_id=device_id, x509=x509 + ) + # device_client = IoTHubModuleClient.create_from_connection_string(conn_str) + + # Connect the client. + await device_client.connect() + + # await device_client.get_storage_info_for_blob("fake_device", "fake_method_params") + + # get the storage sas + blob_name = "fakeBlobName12" + storage_info = await device_client.get_storage_info_for_blob(blob_name) + + # upload to blob + connection = http.client.HTTPSConnection(hostname) + connection.connect() + # notify iot hub of blob upload result + # await device_client.notify_upload_result(storage_blob_result) + storage_blob_result = await storage_blob(storage_info) + pp = pprint.PrettyPrinter(indent=4) + pp.pprint(storage_blob_result) + connection.close() + await device_client.notify_blob_upload_status( + storage_info["correlationId"], True, 200, "fake status description" + ) + + # Finally, disconnect + await device_client.disconnect() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/azure-iot-device/samples/legacy-samples/README.md b/azure-iot-device/samples/legacy-samples/README.md deleted file mode 100644 index f1a48b069..000000000 --- a/azure-iot-device/samples/legacy-samples/README.md +++ /dev/null @@ -1,54 +0,0 @@ -# Legacy Scenario Samples for the Azure IoT Hub Device SDK - -This directory contains samples showing how to use the various features of Azure IoT Hub Device SDK with the Azure IoT Hub and Azure IoT Edge. - -**These samples are legacy samples**, they use the sycnhronous API intended for use with Python 2.7 and 3.4, or in compatibility scenarios with later versions. We recommend you use the [asynchronous API instead](../advanced-hub-scenarios). - - -## IoTHub Device Samples -In order to use these samples, you **must** set your Device Connection String in the environment variable `IOTHUB_DEVICE_CONNECTION_STRING`. - -* [send_message.py](send_message.py) - Send multiple telmetry messages in parallel from a device to the Azure IoT Hub. - * You can monitor the Azure IoT Hub for messages received by using the following Azure CLI command: - ```bash - az iot hub monitor-events --hub-name --output table - ``` -* [receive_message.py](receive_message.py) - Receive Cloud-to-Device (C2D) messages sent from the Azure IoT Hub to a device. - * In order to send a C2D message, use the following Azure CLI command: - ``` - az iot device c2d-message send --device-id --hub-name --data - ``` -* [receive_direct_method.py](receive_direct_method.py) - Receive direct method requests on a device from the Azure IoT Hub and send responses back - * In order to invoke a direct method, use the following Azure CLI command: - ``` - az iot hub invoke-device-method --device-id --hub-name --method-name - ``` -* [receive_twin_desired_properties_patch](receive_twin_desired_properties_patch.py) - Receive an update patch of changes made to the device twin's desired properties - * In order to send a update patch to a device twin's reported properties, use the following Azure CLI command: - ``` - az iot hub device-twin update --device-id --hub-name --set properties.desired.= - ``` -* [update_twin_reported_properties](update_twin_reported_properties.py) - Send an update patch of changes to the device twin's reported properties - * You can see the changes reflected in your device twin by using the following Azure CLI command: - ``` - az iot hub device-twin show --device-id --hub-name - ``` - -## IoT Edge Module Samples -In order to use these samples, they **must** be run from inside an Edge container. - -* [receive_message_on_input.py](receive_message_on_input.py) - Receive messages sent to an Edge module on a specific module input. -* [send_message_to_output.py](send_message_to_output.py) - Send multiple messages in parallel from an Edge module to a specific output - -## DPS Samples - -In order to use these samples, you **must** have the following environment variables :- - -* PROVISIONING_HOST -* PROVISIONING_IDSCOPE -* PROVISIONING_REGISTRATION_ID - -There are 2 ways that your device can get registered to the provisioning service differing in authentication mechanisms and another additional environment variable is needed to for the samples:- - -* [register_symmetric_key.py](register_symmetric_key.py) - Register to provisioning service using a symmetric key. For this you must have the environment variable PROVISIONING_SYMMETRIC_KEY. -* [register_x509.py](register_x509.py) - Register to provisioning service using a symmetric key. For this you must have the environment variable X509_CERT_FILE, X509_KEY_FILE, PASS_PHRASE. diff --git a/azure-iot-device/samples/sync-samples/README.md b/azure-iot-device/samples/sync-samples/README.md new file mode 100644 index 000000000..c345cd57f --- /dev/null +++ b/azure-iot-device/samples/sync-samples/README.md @@ -0,0 +1,75 @@ +# Legacy Scenario Samples for the Azure IoT Hub Device SDK + +This directory contains samples showing how to use the various features of Azure IoT Hub Device SDK with the Azure IoT Hub and Azure IoT Edge. + +**These samples are legacy samples**, they use the sycnhronous API intended for use with Python 2.7, or in compatibility scenarios with later versions. We recommend you use the [asynchronous API instead](../advanced-hub-scenarios). + +## IoTHub Device Samples + +In order to use these samples, you **must** set your Device Connection String in the environment variable `IOTHUB_DEVICE_CONNECTION_STRING`. + +* [send_message.py](send_message.py) - Send multiple telmetry messages in parallel from a device to the Azure IoT Hub. + * You can monitor the Azure IoT Hub for messages received by using the following Azure CLI command: + + ```Shell + bash az iot hub monitor-events --hub-name --output table``` + +* [receive_message.py](receive_message.py) - Receive Cloud-to-Device (C2D) messages sent from the Azure IoT Hub to a device. + * In order to send a C2D message, use the following Azure CLI command: + + ```Shell + az iot device c2d-message send --device-id --hub-name --data + ``` + +* [receive_direct_method.py](receive_direct_method.py) - Receive direct method requests on a device from the Azure IoT Hub and send responses back + * In order to invoke a direct method, use the following Azure CLI command: + + ```Shell + az iot hub invoke-device-method --device-id --hub-name --method-name + ``` + +* [receive_twin_desired_properties_patch](receive_twin_desired_properties_patch.py) - Receive an update patch of changes made to the device twin's desired properties + * In order to send a update patch to a device twin's reported properties, use the following Azure CLI command: + + ```Shell + az iot hub device-twin update --device-id --hub-name --set properties.desired.= + ``` + +* [update_twin_reported_properties](update_twin_reported_properties.py) - Send an update patch of changes to the device twin's reported properties + * You can see the changes reflected in your device twin by using the following Azure CLI command: + + ```Shell + az iot hub device-twin show --device-id --hub-name + ``` + +## IoT Edge Module Samples + +In order to use these samples, they **must** be run from inside an Edge container. + +* [receive_message_on_input.py](receive_message_on_input.py) - Receive messages sent to an Edge module on a specific module input. +* [send_message_to_output.py](send_message_to_output.py) - Send multiple messages in parallel from an Edge module to a specific output + +## DPS Samples + +### Individual + +In order to use these samples, you **must** have the following environment variables :- + +* PROVISIONING_HOST +* PROVISIONING_IDSCOPE +* PROVISIONING_REGISTRATION_ID + +There are 2 ways that your device can get registered to the provisioning service differing in authentication mechanisms and another additional environment variable is needed to for the samples:- + +* [provision_symmetric_key.py](provision_symmetric_key.py) - Provision a device to IoTHub by registering to the Device Provisioning Service using a symmetric key. For this you must have the environment variable PROVISIONING_SYMMETRIC_KEY. +* [provision_symmetric_key_with_payload.py](provision_symmetric_key_with_payload.py) - Provision a device to IoTHub by registering to the Device Provisioning Service using a symmetric key while supplying a custom payload. For this you must have the environment variable PROVISIONING_SYMMETRIC_KEY. +* [provision_x509.py](provision_x509.py) - Provision a device to IoTHub by registering to the Device Provisioning Service using a symmetric key. For this you must have the environment variable X509_CERT_FILE, X509_KEY_FILE, PASS_PHRASE. + +#### Group + +In order to use these samples, you **must** have the following environment variables :- + +* PROVISIONING_HOST +* PROVISIONING_IDSCOPE + +* [provision_symmetric_key_group.py](provision_symmetric_key_group.py) - Provision multiple devices to IoTHub by registering them to the Device Provisioning Service using derived symmetric keys. For this you must have the environment variables PROVISIONING_MASTER_SYMMETRIC_KEY, PROVISIONING_DEVICE_ID_1, PROVISIONING_DEVICE_ID_2, PROVISIONING_DEVICE_ID_3. \ No newline at end of file diff --git a/azure-iot-device/samples/legacy-samples/get_twin.py b/azure-iot-device/samples/sync-samples/get_twin.py similarity index 100% rename from azure-iot-device/samples/legacy-samples/get_twin.py rename to azure-iot-device/samples/sync-samples/get_twin.py diff --git a/azure-iot-device/samples/legacy-samples/register_symmetric_key.py b/azure-iot-device/samples/sync-samples/provision_symmetric_key.py similarity index 94% rename from azure-iot-device/samples/legacy-samples/register_symmetric_key.py rename to azure-iot-device/samples/sync-samples/provision_symmetric_key.py index c9d0124d8..9a6beae6f 100644 --- a/azure-iot-device/samples/legacy-samples/register_symmetric_key.py +++ b/azure-iot-device/samples/sync-samples/provision_symmetric_key.py @@ -23,7 +23,7 @@ registration_result = provisioning_device_client.register() print(registration_result) # Individual attributes can be seen as well -print("The request_id was :-") -print(registration_result.request_id) +print("The status was :-") +print(registration_result.status) print("The etag is :-") print(registration_result.registration_state.etag) diff --git a/azure-iot-device/samples/sync-samples/provision_symmetric_key_and_send_telemetry.py b/azure-iot-device/samples/sync-samples/provision_symmetric_key_and_send_telemetry.py new file mode 100644 index 000000000..2e9bd41c0 --- /dev/null +++ b/azure-iot-device/samples/sync-samples/provision_symmetric_key_and_send_telemetry.py @@ -0,0 +1,62 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from azure.iot.device import ProvisioningDeviceClient +import os +import time +from azure.iot.device import IoTHubDeviceClient, Message +import uuid + +provisioning_host = os.getenv("PROVISIONING_HOST") +id_scope = os.getenv("PROVISIONING_IDSCOPE") +registration_id = os.getenv("PROVISIONING_REGISTRATION_ID") +symmetric_key = os.getenv("PROVISIONING_SYMMETRIC_KEY") + +provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( + provisioning_host=provisioning_host, + registration_id=registration_id, + id_scope=id_scope, + symmetric_key=symmetric_key, +) + +registration_result = provisioning_device_client.register() +# The result can be directly printed to view the important details. +print(registration_result) + +# Individual attributes can be seen as well +print("The request_id was :-") +print(registration_result.request_id) +print("The etag is :-") +print(registration_result.registration_state.etag) + +if registration_result.status == "assigned": + print("Will send telemetry from the provisioned device") + # Create device client from the above result + device_client = IoTHubDeviceClient.create_from_symmetric_key( + symmetric_key=symmetric_key, + hostname=registration_result.registration_state.assigned_hub, + device_id=registration_result.registration_state.device_id, + ) + + # Connect the client. + device_client.connect() + + for i in range(1, 6): + print("sending message #" + str(i)) + device_client.send_message("test payload message " + str(i)) + time.sleep(1) + + for i in range(6, 11): + print("sending message #" + str(i)) + msg = Message("test wind speed " + str(i)) + msg.message_id = uuid.uuid4() + msg.custom_properties["tornado-warning"] = "yes" + device_client.send_message(msg) + time.sleep(1) + + # finally, disconnect + device_client.disconnect() +else: + print("Can not send telemetry from the provisioned device") diff --git a/azure-iot-device/samples/sync-samples/provision_symmetric_key_group.py b/azure-iot-device/samples/sync-samples/provision_symmetric_key_group.py new file mode 100644 index 000000000..87ee023c3 --- /dev/null +++ b/azure-iot-device/samples/sync-samples/provision_symmetric_key_group.py @@ -0,0 +1,85 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import os +import base64 +import hmac +import hashlib +from azure.iot.device import ProvisioningDeviceClient + +provisioning_host = os.getenv("PROVISIONING_HOST") +id_scope = os.getenv("PROVISIONING_IDSCOPE") + +# These are the names of the devices that will eventually show up on the IoTHub +device_id_1 = os.getenv("PROVISIONING_DEVICE_ID_1") +device_id_2 = os.getenv("PROVISIONING_DEVICE_ID_2") +device_id_3 = os.getenv("PROVISIONING_DEVICE_ID_3") + +# For computation of device keys +device_ids_to_keys = {} + +# Keep a dictionary for results +results = {} + +# NOTE : Only for illustration purposes. +# This is how a device key can be derived from the group symmetric key. +# This is just a helper function to show how it is done. +# Please don't directly store the group key on the device. +# Follow the following method to compute the device key somewhere else. + + +def derive_device_key(device_id, group_symmetric_key): + """ + The unique device ID and the group master key should be encoded into "utf-8" + After this the encoded group master key must be used to compute an HMAC-SHA256 of the encoded registration ID. + Finally the result must be converted into Base64 format. + The device key is the "utf-8" decoding of the above result. + """ + message = device_id.encode("utf-8") + signing_key = base64.b64decode(group_symmetric_key.encode("utf-8")) + signed_hmac = hmac.HMAC(signing_key, message, hashlib.sha256) + device_key_encoded = base64.b64encode(signed_hmac.digest()) + return device_key_encoded.decode("utf-8") + + +# derived_device_key has been computed already using the helper function somewhere else +# AND NOT on this sample. Do not use the direct master key on this sample to compute device key. +derived_device_key_1 = "some_value_already_computed" +derived_device_key_2 = "some_value_already_computed" +derived_device_key_3 = "some_value_already_computed" + + +device_ids_to_keys[device_id_1] = derived_device_key_1 +device_ids_to_keys[device_id_1] = derived_device_key_2 +device_ids_to_keys[device_id_1] = derived_device_key_3 + + +def register_device(registration_id): + + provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( + provisioning_host=provisioning_host, + registration_id=registration_id, + id_scope=id_scope, + symmetric_key=device_ids_to_keys[registration_id], + ) + + return provisioning_device_client.register() + + +for device_id in device_ids_to_keys: + registration_result = register_device(registration_id=device_id) + results[device_id] = registration_result + + +for device_id in device_ids_to_keys: + # The result can be directly printed to view the important details. + registration_result = results[device_id] + print(registration_result) + # Individual attributes can be seen as well + print("The request_id was :-") + print(registration_result.request_id) + print("The etag is :-") + print(registration_result.registration_state.etag) + print("\n") diff --git a/azure-iot-device/samples/sync-samples/provision_symmetric_key_with_payload.py b/azure-iot-device/samples/sync-samples/provision_symmetric_key_with_payload.py new file mode 100644 index 000000000..63a04f7a5 --- /dev/null +++ b/azure-iot-device/samples/sync-samples/provision_symmetric_key_with_payload.py @@ -0,0 +1,41 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import os +from azure.iot.device import ProvisioningDeviceClient + + +class Wizard(object): + def __init__(self, first_name, last_name, dict_of_stuff): + self.first_name = first_name + self.last_name = last_name + self.props = dict_of_stuff + + +provisioning_host = os.getenv("PROVISIONING_HOST") +id_scope = os.getenv("PROVISIONING_IDSCOPE") +registration_id = os.getenv("PROVISIONING_REGISTRATION_ID") +symmetric_key = os.getenv("PROVISIONING_SYMMETRIC_KEY") + +provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( + provisioning_host=provisioning_host, + registration_id=registration_id, + id_scope=id_scope, + symmetric_key=symmetric_key, +) + +properties = {"House": "Gryffindor", "Muggle-Born": "False"} +wizard_a = Wizard("Harry", "Potter", properties) + +provisioning_device_client.provisioning_payload = wizard_a +registration_result = provisioning_device_client.register() +# The result can be directly printed to view the important details. +print(registration_result) + +# Individual attributes can be seen as well +print("The request_id was :-") +print(registration_result.request_id) +print("The etag is :-") +print(registration_result.registration_state.etag) diff --git a/azure-iot-device/samples/legacy-samples/register_x509.py b/azure-iot-device/samples/sync-samples/provision_x509.py similarity index 92% rename from azure-iot-device/samples/legacy-samples/register_x509.py rename to azure-iot-device/samples/sync-samples/provision_x509.py index 0494a2d7c..e809b5de0 100644 --- a/azure-iot-device/samples/legacy-samples/register_x509.py +++ b/azure-iot-device/samples/sync-samples/provision_x509.py @@ -3,10 +3,6 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- - - -# This is for illustration purposes only. The sample will not work currently. - import os from azure.iot.device import ProvisioningDeviceClient, X509 diff --git a/azure-iot-device/samples/sync-samples/provision_x509_and_send_telemetry.py b/azure-iot-device/samples/sync-samples/provision_x509_and_send_telemetry.py new file mode 100644 index 000000000..2222f72a6 --- /dev/null +++ b/azure-iot-device/samples/sync-samples/provision_x509_and_send_telemetry.py @@ -0,0 +1,63 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import os +from azure.iot.device import ProvisioningDeviceClient, X509 +import time +from azure.iot.device import IoTHubDeviceClient, Message +import uuid + + +provisioning_host = os.getenv("PROVISIONING_HOST") +id_scope = os.getenv("PROVISIONING_IDSCOPE") +registration_id = os.getenv("DPS_X509_REGISTRATION_ID") + +x509 = X509( + cert_file=os.getenv("X509_CERT_FILE"), + key_file=os.getenv("X509_KEY_FILE"), + pass_phrase=os.getenv("PASS_PHRASE"), +) + +provisioning_device_client = ProvisioningDeviceClient.create_from_x509_certificate( + provisioning_host=provisioning_host, + registration_id=registration_id, + id_scope=id_scope, + x509=x509, +) + +registration_result = provisioning_device_client.register() + +# The result can be directly printed to view the important details. +print(registration_result) + +if registration_result.status == "assigned": + print("Will send telemetry from the provisioned device") + # Create device client from the above result + device_client = IoTHubDeviceClient.create_from_x509_certificate( + x509=x509, + hostname=registration_result.registration_state.assigned_hub, + device_id=registration_result.registration_state.device_id, + ) + + # Connect the client. + device_client.connect() + + for i in range(1, 6): + print("sending message #" + str(i)) + device_client.send_message("test payload message " + str(i)) + time.sleep(1) + + for i in range(6, 11): + print("sending message #" + str(i)) + msg = Message("test wind speed " + str(i)) + msg.message_id = uuid.uuid4() + msg.custom_properties["tornado-warning"] = "yes" + device_client.send_message(msg) + time.sleep(1) + + # finally, disconnect + +else: + print("Can not send telemetry from the provisioned device") diff --git a/azure-iot-device/samples/legacy-samples/receive_direct_method.py b/azure-iot-device/samples/sync-samples/receive_direct_method.py similarity index 100% rename from azure-iot-device/samples/legacy-samples/receive_direct_method.py rename to azure-iot-device/samples/sync-samples/receive_direct_method.py diff --git a/azure-iot-device/samples/legacy-samples/receive_message.py b/azure-iot-device/samples/sync-samples/receive_message.py similarity index 100% rename from azure-iot-device/samples/legacy-samples/receive_message.py rename to azure-iot-device/samples/sync-samples/receive_message.py diff --git a/azure-iot-device/samples/legacy-samples/receive_message_on_input.py b/azure-iot-device/samples/sync-samples/receive_message_on_input.py similarity index 100% rename from azure-iot-device/samples/legacy-samples/receive_message_on_input.py rename to azure-iot-device/samples/sync-samples/receive_message_on_input.py diff --git a/azure-iot-device/samples/legacy-samples/receive_message_x509.py b/azure-iot-device/samples/sync-samples/receive_message_x509.py similarity index 100% rename from azure-iot-device/samples/legacy-samples/receive_message_x509.py rename to azure-iot-device/samples/sync-samples/receive_message_x509.py diff --git a/azure-iot-device/samples/legacy-samples/receive_twin_desired_properties_patch.py b/azure-iot-device/samples/sync-samples/receive_twin_desired_properties_patch.py similarity index 100% rename from azure-iot-device/samples/legacy-samples/receive_twin_desired_properties_patch.py rename to azure-iot-device/samples/sync-samples/receive_twin_desired_properties_patch.py diff --git a/azure-iot-device/samples/legacy-samples/send_message.py b/azure-iot-device/samples/sync-samples/send_message.py similarity index 100% rename from azure-iot-device/samples/legacy-samples/send_message.py rename to azure-iot-device/samples/sync-samples/send_message.py diff --git a/azure-iot-device/samples/legacy-samples/send_message_to_output.py b/azure-iot-device/samples/sync-samples/send_message_to_output.py similarity index 100% rename from azure-iot-device/samples/legacy-samples/send_message_to_output.py rename to azure-iot-device/samples/sync-samples/send_message_to_output.py diff --git a/azure-iot-device/samples/legacy-samples/send_message_via_module_x509.py b/azure-iot-device/samples/sync-samples/send_message_via_module_x509.py similarity index 100% rename from azure-iot-device/samples/legacy-samples/send_message_via_module_x509.py rename to azure-iot-device/samples/sync-samples/send_message_via_module_x509.py diff --git a/azure-iot-device/samples/sync-samples/send_message_via_proxy.py b/azure-iot-device/samples/sync-samples/send_message_via_proxy.py new file mode 100644 index 000000000..9bde1fb20 --- /dev/null +++ b/azure-iot-device/samples/sync-samples/send_message_via_proxy.py @@ -0,0 +1,75 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import time +import uuid +from azure.iot.device import IoTHubDeviceClient, Message, ProxyOptions +import socks +import logging + +logging.basicConfig(level=logging.DEBUG) + +# The connection string for a device should never be stored in code. For the sake of simplicity we're using an environment variable here. +conn_str = os.getenv("IOTHUB_DEVICE_CONNECTION_STRING") + +# Create proxy options when trying to send via proxy +proxy_opts = ProxyOptions( + proxy_type=socks.HTTP, proxy_addr="127.0.0.1", proxy_port=8888 # localhost +) +# The client object is used to interact with your Azure IoT hub. +device_client = IoTHubDeviceClient.create_from_connection_string( + conn_str, websockets=True, proxy_options=proxy_opts +) + +# Connect the client. +device_client.connect() + +# send 2 messages with 2 system properties & 1 custom property with a 1 second pause between each message +for i in range(1, 3): + print("sending message #" + str(i)) + msg = Message("test wind speed " + str(i)) + msg.message_id = uuid.uuid4() + msg.correlation_id = "correlation-1234" + msg.custom_properties["tornado-warning"] = "yes" + device_client.send_message(msg) + time.sleep(1) + +# send 2 messages with only custom property with a 1 second pause between each message +for i in range(3, 5): + print("sending message #" + str(i)) + msg = Message("test wind speed " + str(i)) + msg.custom_properties["tornado-warning"] = "yes" + device_client.send_message(msg) + time.sleep(1) + +# send 2 messages with only system properties with a 1 second pause between each message +for i in range(5, 7): + print("sending message #" + str(i)) + msg = Message("test wind speed " + str(i)) + msg.message_id = uuid.uuid4() + msg.correlation_id = "correlation-1234" + device_client.send_message(msg) + time.sleep(1) + +# send 2 messages with 1 system property and 1 custom property with a 1 second pause between each message +for i in range(7, 9): + print("sending message #" + str(i)) + msg = Message("test wind speed " + str(i)) + msg.message_id = uuid.uuid4() + msg.custom_properties["tornado-warning"] = "yes" + device_client.send_message(msg) + time.sleep(1) + +# send only string messages +for i in range(9, 11): + print("sending message #" + str(i)) + device_client.send_message("test payload message " + str(i)) + time.sleep(1) + + +# finally, disconnect +device_client.disconnect() diff --git a/azure-iot-device/samples/legacy-samples/send_message_x509.py b/azure-iot-device/samples/sync-samples/send_message_x509.py similarity index 100% rename from azure-iot-device/samples/legacy-samples/send_message_x509.py rename to azure-iot-device/samples/sync-samples/send_message_x509.py diff --git a/azure-iot-device/samples/legacy-samples/update_twin_reported_properties.py b/azure-iot-device/samples/sync-samples/update_twin_reported_properties.py similarity index 100% rename from azure-iot-device/samples/legacy-samples/update_twin_reported_properties.py rename to azure-iot-device/samples/sync-samples/update_twin_reported_properties.py diff --git a/azure-iot-device/setup.py b/azure-iot-device/setup.py index de66d629a..e196e3190 100644 --- a/azure-iot-device/setup.py +++ b/azure-iot-device/setup.py @@ -36,7 +36,7 @@ setup( version=constant["VERSION"], description="Microsoft Azure IoT Device Library", license="MIT License", - url="https://github.com/Azure/azure-iot-sdk-python-preview", + url="https://github.com/Azure/azure-iot-sdk-python/tree/master/azure-iot-device", author="Microsoft Corporation", author_email="opensource@microsoft.com", long_description=_long_description, @@ -54,6 +54,7 @@ setup( "Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", ], install_requires=[ # Define sub-dependencies due to pip dependency resolution bug @@ -71,6 +72,8 @@ setup( "requests-unixsocket>=0.1.5,<1.0.0", "janus>=0.4.0,<1.0.0;python_version>='3.5'", "futures;python_version == '2.7'", + "PySocks", + "win-inet-pton;python_version == '2.7'", ], extras_require={":python_version<'3.0'": ["azure-iot-nspkg>=1.0.1"]}, python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3*, <4", diff --git a/azure-iot-device/tests/common/conftest.py b/azure-iot-device/tests/common/conftest.py index c4414c3dd..ae2fe0f3d 100644 --- a/azure-iot-device/tests/common/conftest.py +++ b/azure-iot-device/tests/common/conftest.py @@ -15,11 +15,6 @@ if sys.version_info < (3, 5): collect_ignore.append("test_asyncio_compat.py") -@pytest.fixture -def fake_error(): - return RuntimeError("__fake_error__") - - @pytest.fixture def fake_return_arg_value(): return "__fake_return_arg_value__" diff --git a/azure-iot-device/tests/common/pipeline/conftest.py b/azure-iot-device/tests/common/pipeline/conftest.py index fe75531f5..83e996b1c 100644 --- a/azure-iot-device/tests/common/pipeline/conftest.py +++ b/azure-iot-device/tests/common/pipeline/conftest.py @@ -5,15 +5,8 @@ # -------------------------------------------------------------------------- from tests.common.pipeline.fixtures import ( - callback, - fake_exception, - fake_base_exception, - event, - op, - op2, - op3, - finally_op, - new_op, + arbitrary_event, + arbitrary_op, fake_pipeline_thread, fake_non_pipeline_thread, unhandled_error_handler, diff --git a/azure-iot-device/tests/common/pipeline/fixtures.py b/azure-iot-device/tests/common/pipeline/fixtures.py index 30d821e3a..56a7886b9 100644 --- a/azure-iot-device/tests/common/pipeline/fixtures.py +++ b/azure-iot-device/tests/common/pipeline/fixtures.py @@ -6,7 +6,7 @@ import pytest import threading from tests.common.pipeline import helpers -from azure.iot.device.common import unhandled_exceptions +from azure.iot.device.common import handle_exceptions from azure.iot.device.common.pipeline import ( pipeline_events_base, pipeline_ops_base, @@ -14,75 +14,25 @@ from azure.iot.device.common.pipeline import ( ) -# TODO: remove this fixture -# Using it is dangerous if multiple ops use this same callback -# within the scope of the same test -# (i.e. an op under test, and an op run in test setup) -# What happens is that those operations all are tied together with this -# same callback mock, and the same callback is called multiple times -# leading to unexpected behavior -@pytest.fixture -def callback(mocker): - return mocker.MagicMock() - - -@pytest.fixture -def fake_exception(): - return Exception() - - -@pytest.fixture -def fake_base_exception(): - return helpers.UnhandledException() - - -class FakeEvent(pipeline_events_base.PipelineEvent): +class ArbitraryEvent(pipeline_events_base.PipelineEvent): def __init__(self): - super(FakeEvent, self).__init__() + super(ArbitraryEvent, self).__init__() @pytest.fixture -def event(): - return FakeEvent() +def arbitrary_event(): + return ArbitraryEvent() -class FakeOperation(pipeline_ops_base.PipelineOperation): +class ArbitraryOperation(pipeline_ops_base.PipelineOperation): def __init__(self, callback=None): - super(FakeOperation, self).__init__(callback=callback) + super(ArbitraryOperation, self).__init__(callback=callback) @pytest.fixture -def op(callback): - op = FakeOperation(callback=callback) - op.name = "op" - return op - - -@pytest.fixture -def op2(callback): - op = FakeOperation(callback=callback) - op.name = "op2" - return op - - -@pytest.fixture -def op3(callback): - op = FakeOperation(callback=callback) - op.name = "op3" - return op - - -@pytest.fixture -def finally_op(callback): - op = FakeOperation(callback=callback) - op.name = "finally_op" - return op - - -@pytest.fixture -def new_op(callback): - op = FakeOperation(callback=callback) - op.name = "new_op" +def arbitrary_op(mocker): + op = ArbitraryOperation(callback=mocker.MagicMock()) + mocker.spy(op, "complete") return op @@ -115,4 +65,4 @@ def fake_non_pipeline_thread(): @pytest.fixture def unhandled_error_handler(mocker): - return mocker.patch.object(unhandled_exceptions, "exception_caught_in_background_thread") + return mocker.patch.object(handle_exceptions, "handle_background_exception") diff --git a/azure-iot-device/tests/common/pipeline/helpers.py b/azure-iot-device/tests/common/pipeline/helpers.py index 0fc926689..c5c0ab818 100644 --- a/azure-iot-device/tests/common/pipeline/helpers.py +++ b/azure-iot-device/tests/common/pipeline/helpers.py @@ -7,13 +7,14 @@ import inspect import pytest import functools from threading import Event +from azure.iot.device.common import handle_exceptions from azure.iot.device.common.pipeline import ( pipeline_events_base, pipeline_ops_base, pipeline_stages_base, pipeline_events_mqtt, pipeline_ops_mqtt, - operation_flow, + config, ) try: @@ -21,15 +22,79 @@ try: except ImportError: from inspect import getargspec + +class StageRunOpTestBase(object): + """All PipelineStage .run_op() tests should inherit from this base class. + It provides basic tests for dealing with exceptions. + """ + + @pytest.mark.it( + "Completes the operation with failure if an unexpected Exception is raised while executing the operation" + ) + def test_completes_operation_with_error(self, mocker, stage, op, arbitrary_exception): + stage._run_op = mocker.MagicMock(side_effect=arbitrary_exception) + # mocker.spy(op, "complete") + + stage.run_op(op) + + assert op.completed + assert op.error is arbitrary_exception + # assert op.complete.call_count == 1 + # assert op.complete.call_args == mocker.call(error=arbitrary_exception) + + @pytest.mark.it( + "Allows any BaseException that was raised during execution of the operation to propogate" + ) + def test_base_exception_propogates(self, mocker, stage, op, arbitrary_base_exception): + stage._run_op = mocker.MagicMock(side_effect=arbitrary_base_exception) + + with pytest.raises(arbitrary_base_exception.__class__) as e_info: + stage.run_op(op) + assert e_info.value is arbitrary_base_exception + + +class StageHandlePipelineEventTestBase(object): + """All PipelineStage .handle_pipeline_event() tests should inherit from this base class. + It provides basic tests for dealing with exceptions. + """ + + @pytest.mark.it( + "Sends any unexpected Exceptions raised during handling of the event to the background exception handler" + ) + def test_uses_background_exception_handler(self, mocker, stage, event, arbitrary_exception): + stage._handle_pipeline_event = mocker.MagicMock(side_effect=arbitrary_exception) + mocker.spy(handle_exceptions, "handle_background_exception") + + stage.handle_pipeline_event(event) + + assert handle_exceptions.handle_background_exception.call_count == 1 + assert handle_exceptions.handle_background_exception.call_args == mocker.call( + arbitrary_exception + ) + + @pytest.mark.it("Allows any BaseException raised during handling of the event to propogate") + def test_base_exception_propogates(self, mocker, stage, event, arbitrary_base_exception): + stage._handle_pipeline_event = mocker.MagicMock(side_effect=arbitrary_base_exception) + + with pytest.raises(arbitrary_base_exception.__class__) as e_info: + stage.handle_pipeline_event(event) + assert e_info.value is arbitrary_base_exception + + +############################################ +# EVERYTHING BELOW THIS POINT IS DEPRECATED# +############################################ +# CT-TODO: remove + all_common_ops = [ pipeline_ops_base.ConnectOperation, - pipeline_ops_base.ReconnectOperation, + pipeline_ops_base.ReauthorizeConnectionOperation, pipeline_ops_base.DisconnectOperation, pipeline_ops_base.EnableFeatureOperation, pipeline_ops_base.DisableFeatureOperation, pipeline_ops_base.UpdateSasTokenOperation, - pipeline_ops_base.SendIotRequestAndWaitForResponseOperation, - pipeline_ops_base.SendIotRequestOperation, + pipeline_ops_base.RequestAndResponseOperation, + pipeline_ops_base.RequestOperation, pipeline_ops_mqtt.SetMQTTConnectionArgsOperation, pipeline_ops_mqtt.MQTTPublishOperation, pipeline_ops_mqtt.MQTTSubscribeOperation, @@ -50,50 +115,56 @@ def all_except(all_items, items_to_exclude): return [x for x in all_items if x not in items_to_exclude] -def make_mock_stage(mocker, stage_to_make): - """ - make a stage object that we can use in testing. This stage object is popsulated - by mocker spies, and it has a next stage that can receive events. It does not, - by detfault, have a previous stage or a pipeline root that can receive events - coming back up. The previous stage is added by the tests which which require it. - """ - # because PipelineStage is abstract, we need something concrete - class NextStageForTest(pipeline_stages_base.PipelineStage): - def _execute_op(self, op): - operation_flow.pass_op_to_next_stage(self, op) +class StageTestBase(object): + @pytest.fixture(autouse=True) + def stage_base_configuration(self, stage, mocker): + """ + This fixture configures the stage for testing. This is automatically + applied, so it will be called before your test runs, but it's not + guaranteed to be called before any other fixtures run. If you have + a fixture that needs to rely on the stage being configured, then + you have to add a manual dependency inside that fixture (like we do in + next_stage_succeeds_all_ops below) + """ - def stage_execute_op(self, op): - if getattr(op, "action", None) is None or op.action == "pass": - operation_flow.complete_op(self, op) - elif op.action == "fail" or op.action == "exception": - raise Exception() - elif op.action == "base_exception": - raise UnhandledException() - elif op.action == "pend": - pass - else: - assert False + class NextStageForTest(pipeline_stages_base.PipelineStage): + def _run_op(self, op): + pass - first_stage = stage_to_make() - first_stage.unhandled_error_handler = mocker.Mock() - mocker.spy(first_stage, "_execute_op") - mocker.spy(first_stage, "run_op") + next = NextStageForTest() + root = ( + pipeline_stages_base.PipelineRootStage(config.BasePipelineConfig()) + .append_stage(stage) + .append_stage(next) + ) - next_stage = NextStageForTest() - next_stage._execute_op = functools.partial(stage_execute_op, next_stage) - mocker.spy(next_stage, "_execute_op") - mocker.spy(next_stage, "run_op") + mocker.spy(stage, "_run_op") + mocker.spy(stage, "run_op") - first_stage.next = next_stage - # TODO: this is sloppy. we should have a real root here for testing. - first_stage.pipeline_root = first_stage + mocker.spy(next, "_run_op") + mocker.spy(next, "run_op") - next_stage.previous = first_stage - next_stage.pipeline_root = first_stage + return root - first_stage.pipeline_root.connected = False + @pytest.fixture + def next_stage_succeeds(self, stage, stage_base_configuration, mocker): + def complete_op_success(op): + op.complete() - return first_stage + stage.next._run_op = complete_op_success + mocker.spy(stage.next, "_run_op") + + @pytest.fixture + def next_stage_raises_arbitrary_exception( + self, stage, stage_base_configuration, mocker, arbitrary_exception + ): + stage.next._run_op = mocker.MagicMock(side_effect=arbitrary_exception) + + @pytest.fixture + def next_stage_raises_arbitrary_base_exception( + self, stage, stage_base_configuration, mocker, arbitrary_base_exception + ): + stage.next._run_op = mocker.MagicMock(side_effect=arbitrary_base_exception) def assert_callback_succeeded(op, callback=None): @@ -107,9 +178,10 @@ def assert_callback_succeeded(op, callback=None): except AttributeError: pass assert callback.call_count == 1 - callback_arg = callback.call_args[0][0] - assert callback_arg == op - assert op.error is None + callback_op_arg = callback.call_args[0][0] + assert callback_op_arg == op + callback_error_arg = callback.call_args[1]["error"] + assert callback_error_arg is None def assert_callback_failed(op, callback=None, error=None): @@ -123,20 +195,17 @@ def assert_callback_failed(op, callback=None, error=None): except AttributeError: pass assert callback.call_count == 1 - callback_arg = callback.call_args[0][0] - assert callback_arg == op + callback_op_arg = callback.call_args[0][0] + assert callback_op_arg == op + callback_error_arg = callback.call_args[1]["error"] if error: if isinstance(error, type): - assert isinstance(op.error, error) + assert callback_error_arg.__class__ == error else: - assert op.error is error + assert callback_error_arg is error else: - assert op.error is not None - - -class UnhandledException(BaseException): - pass + assert callback_error_arg is not None def get_arg_count(fn): diff --git a/azure-iot-device/tests/common/pipeline/pipeline_config_test.py b/azure-iot-device/tests/common/pipeline/pipeline_config_test.py new file mode 100644 index 000000000..739fd6f4b --- /dev/null +++ b/azure-iot-device/tests/common/pipeline/pipeline_config_test.py @@ -0,0 +1,92 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest + + +class PipelineConfigInstantiationTestBase(object): + """All PipelineConfig instantiation tests should inherit from this base class. + It provides tests for shared functionality among all PipelineConfigs, derived from + the BasePipelineConfig class. + """ + + @pytest.mark.it( + "Instantiates with the 'websockets' attribute set to the provided 'websockets' parameter" + ) + @pytest.mark.parametrize( + "websockets", [True, False], ids=["websockets == True", "websockets == False"] + ) + def test_websockets_set(self, config_cls, websockets): + config = config_cls(websockets=websockets) + assert config.websockets is websockets + + @pytest.mark.it( + "Instantiates with the 'cipher' attribute set to OpenSSL list formatted version of the provided 'cipher' parameter" + ) + @pytest.mark.parametrize( + "cipher_input, expected_cipher", + [ + pytest.param( + "DHE-RSA-AES128-SHA", + "DHE-RSA-AES128-SHA", + id="Single cipher suite, OpenSSL list formatted string", + ), + pytest.param( + "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", + "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", + id="Multiple cipher suites, OpenSSL list formatted string", + ), + pytest.param( + "DHE_RSA_AES128_SHA", + "DHE-RSA-AES128-SHA", + id="Single cipher suite, as string with '_' delimited algorithms/protocols", + ), + pytest.param( + "DHE_RSA_AES128_SHA:DHE_RSA_AES256_SHA:ECDHE_ECDSA_AES128_GCM_SHA256", + "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", + id="Multiple cipher suites, as string with '_' delimited algorithms/protocols and ':' delimited suites", + ), + pytest.param( + ["DHE-RSA-AES128-SHA"], + "DHE-RSA-AES128-SHA", + id="Single cipher suite, in a list, with '-' delimited algorithms/protocols", + ), + pytest.param( + ["DHE-RSA-AES128-SHA", "DHE-RSA-AES256-SHA", "ECDHE-ECDSA-AES128-GCM-SHA256"], + "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", + id="Multiple cipher suites, in a list, with '-' delimited algorithms/protocols", + ), + pytest.param( + ["DHE_RSA_AES128_SHA"], + "DHE-RSA-AES128-SHA", + id="Single cipher suite, in a list, with '_' delimited algorithms/protocols", + ), + pytest.param( + ["DHE_RSA_AES128_SHA", "DHE_RSA_AES256_SHA", "ECDHE_ECDSA_AES128_GCM_SHA256"], + "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", + id="Multiple cipher suites, in a list, with '_' delimited algorithms/protocols", + ), + ], + ) + def test_cipher(self, config_cls, cipher_input, expected_cipher): + config = config_cls(cipher=cipher_input) + assert config.cipher == expected_cipher + + @pytest.mark.it( + "Raises TypeError if the provided 'cipher' attribute is neither list nor string" + ) + @pytest.mark.parametrize( + "cipher", + [ + pytest.param(123, id="int"), + pytest.param( + {"cipher1": "DHE-RSA-AES128-SHA", "cipher2": "DHE_RSA_AES256_SHA"}, id="dict" + ), + pytest.param(object(), id="complex object"), + ], + ) + def test_invalid_cipher_param(self, config_cls, cipher): + with pytest.raises(TypeError): + config_cls(cipher=cipher) diff --git a/azure-iot-device/tests/common/pipeline/pipeline_data_object_test.py b/azure-iot-device/tests/common/pipeline/pipeline_event_test.py similarity index 84% rename from azure-iot-device/tests/common/pipeline/pipeline_data_object_test.py rename to azure-iot-device/tests/common/pipeline/pipeline_event_test.py index cf9b9c5f5..3c28713a1 100644 --- a/azure-iot-device/tests/common/pipeline/pipeline_data_object_test.py +++ b/azure-iot-device/tests/common/pipeline/pipeline_event_test.py @@ -8,6 +8,8 @@ import inspect fake_count = 0 +# CT-TODO: refactor this module + def get_next_fake_value(): """ @@ -21,31 +23,9 @@ def get_next_fake_value(): return "__fake_value_{}__".format(fake_count) -base_operation_defaults = {"needs_connection": False, "error": None} base_event_defaults = {} -def add_operation_test( - cls, module, extra_defaults={}, positional_arguments=[], keyword_arguments={} -): - """ - Add a test class to test the given PipelineOperation class. The class that - we're testing is passed in the cls parameter, and the different initialization - constants are passed with the named arguments that follow. - """ - all_extra_defaults = extra_defaults.copy() - all_extra_defaults.update(name=cls.__name__) - - add_instantiation_test( - cls=cls, - module=module, - defaults=base_operation_defaults, - extra_defaults=all_extra_defaults, - positional_arguments=positional_arguments, - keyword_arguments=keyword_arguments, - ) - - def add_event_test(cls, module, extra_defaults={}, positional_arguments=[], keyword_arguments={}): """ Add a test class to test the given PipelineOperation class. The class that @@ -74,7 +54,7 @@ def add_instantiation_test( """ # `defaults` contains an array of object attributes that should be set when - # we call the initializer will all of the required positional arguments + # we call the initializer with all of the required positional arguments # and none of the optional keyword arguments. all_defaults = defaults.copy() diff --git a/azure-iot-device/tests/common/pipeline/pipeline_ops_test.py b/azure-iot-device/tests/common/pipeline/pipeline_ops_test.py new file mode 100644 index 000000000..4c676fafd --- /dev/null +++ b/azure-iot-device/tests/common/pipeline/pipeline_ops_test.py @@ -0,0 +1,782 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest +import logging +import threading + +from azure.iot.device.common.pipeline.pipeline_ops_base import PipelineOperation +from azure.iot.device.common import handle_exceptions +from azure.iot.device.common.pipeline import pipeline_exceptions + +logging.basicConfig(level=logging.DEBUG) + + +def add_operation_tests( + test_module, + op_class_under_test, + op_test_config_class, + extended_op_instantiation_test_class=None, +): + """ + Add shared tests for an Operation class to a testing module. + These tests need to be done for every Operation class. + + :param test_module: A reference to the test module to add tests to + :param op_class_under_test: A reference to the specific Operation class under test + :param op_test_config_class: A class providing fixtures specific to the Operation class + under test. This class must define the following fixtures: + - "cls_type" (which returns a reference to the Operation class under test) + - "init_kwargs" (which returns a dictionary of kwargs and associated values used to + instantiate the class) + :param extended_op_instantiation_test_class: A class defining instantiation tests that are + specific to the Operation class under test, and not shared with all Operations. + Note that you may override shared instantiation tests defined in this function within + the provided test class (e.g. test_needs_connection) + """ + + # Extend the provided test config class + class OperationTestConfigClass(op_test_config_class): + @pytest.fixture + def op(self, cls_type, init_kwargs, mocker): + op = cls_type(**init_kwargs) + mocker.spy(op, "complete") + return op + + @pytest.mark.describe("{} - Instantiation".format(op_class_under_test.__name__)) + class OperationBaseInstantiationTests(OperationTestConfigClass): + @pytest.mark.it("Initializes 'name' attribute as the classname") + def test_name(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.name == op.__class__.__name__ + + @pytest.mark.it("Initializes 'completed' attribute as False") + def test_completed(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.completed is False + + @pytest.mark.it("Initializes 'completing' attribute as False") + def test_completing(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.completing is False + + @pytest.mark.it("Initializes 'error' attribute as None") + def test_error(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.error is None + + # NOTE: this test should be overridden for operations that set this value to True + @pytest.mark.it("Initializes 'needs_connection' attribute as False") + def test_needs_connection(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.needs_connection is False + + @pytest.mark.it("Initializes 'callback_stack' list attribute with the provided callback") + def test_callback_added_to_list(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert len(op.callback_stack) == 1 + assert op.callback_stack[0] is init_kwargs["callback"] + + # If an extended operation instantiation test class is provided, use those tests as well. + # By using the extended_op_instantation_test_class as the first parent class, this ensures that + # tests from OperationBaseInstantiationTests (e.g. test_needs_connection) can be overwritten by + # tests provided in extended_op_instantiation_test_class. + if extended_op_instantiation_test_class: + + class OperationInstantiationTests( + extended_op_instantiation_test_class, OperationBaseInstantiationTests + ): + pass + + else: + + class OperationInstantiationTests(OperationBaseInstantiationTests): + pass + + @pytest.mark.describe("{} - .add_callback()".format(op_class_under_test.__name__)) + class OperationAddCallbackTests(OperationTestConfigClass): + @pytest.fixture( + params=["Currently completing with no error", "Currently completing with error"] + ) + def error(self, request, arbitrary_exception): + if request.param == "Currently completing with no error": + return None + else: + return arbitrary_exception + + @pytest.mark.it("Adds a callback to the operation's callback stack'") + def test_adds_callback(self, mocker, op): + # Because op was instantiated with a callback, because 'callback' is a + # required parameter, there will already be one callback on the stack + # before we add additional ones. + assert len(op.callback_stack) == 1 + cb1 = mocker.MagicMock() + op.add_callback(cb1) + assert len(op.callback_stack) == 2 + assert op.callback_stack[1] == cb1 + + cb2 = mocker.MagicMock() + op.add_callback(cb2) + assert len(op.callback_stack) == 3 + assert op.callback_stack[1] == cb1 + assert op.callback_stack[2] == cb2 + + @pytest.mark.it( + "Raises an OperationError if attempting to add a callback to an already-completed operation" + ) + def test_already_completed_callback(self, mocker, op): + op.complete() + assert op.completed + + with pytest.raises(pipeline_exceptions.OperationError): + op.add_callback(mocker.MagicMock()) + + @pytest.mark.it( + "Raises an OperationError if attempting to add a callback to an operation that is currently undergoing the completion process" + ) + def test_currently_completing(self, mocker, op, error): + mocker.spy(handle_exceptions, "handle_background_exception") + + def cb(op, error): + with pytest.raises(pipeline_exceptions.OperationError): + # Add a callback during completion of the callback, i.e. while op completion is in progress + op.add_callback(mocker.MagicMock()) + + mock_cb = mocker.MagicMock(side_effect=cb) + op.add_callback(mock_cb) + + op.complete(error) + + assert mock_cb.call_count == 1 + + @pytest.mark.describe("{} - .spawn_worker_op()".format(op_class_under_test.__name__)) + class OperationSpawnWorkerOpTests(OperationTestConfigClass): + @pytest.fixture + def worker_op_type(self): + class SomeOperationType(PipelineOperation): + def __init__(self, arg1, arg2, arg3, callback): + super(SomeOperationType, self).__init__(callback=callback) + + return SomeOperationType + + @pytest.fixture + def worker_op_kwargs(self): + kwargs = {"arg1": 1, "arg2": 2, "arg3": 3} + return kwargs + + @pytest.mark.it( + "Creates and returns an new instance of the Operation class specified in the 'worker_op_type' parameter" + ) + def test_returns_worker_op_instance(self, op, worker_op_type, worker_op_kwargs): + worker_op = op.spawn_worker_op(worker_op_type, **worker_op_kwargs) + assert isinstance(worker_op, worker_op_type) + + @pytest.mark.it( + "Instantiates the returned worker operation using the provided **kwargs parameters (not including 'callback')" + ) + def test_creates_worker_op_with_provided_kwargs(self, mocker, op, worker_op_kwargs): + mock_instance = mocker.MagicMock() + mock_type = mocker.MagicMock(return_value=mock_instance) + mock_type.__name__ = "mock type" # this is needed for log statements + assert "callback" not in worker_op_kwargs + + worker_op = op.spawn_worker_op(mock_type, **worker_op_kwargs) + + assert worker_op is mock_instance + assert mock_type.call_count == 1 + + # Show that all provided kwargs are used. Note that this test does NOT show that + # ONLY the provided kwargs are used - because there ARE additional kwargs added. + for kwarg in worker_op_kwargs: + assert mock_type.call_args[1][kwarg] == worker_op_kwargs[kwarg] + + @pytest.mark.it( + "Adds a secondary callback to the worker operation after instantiation, if 'callback' is included in the provided **kwargs parameters" + ) + def test_adds_callback_to_worker_op(self, mocker, op, worker_op_kwargs): + mock_instance = mocker.MagicMock() + mock_type = mocker.MagicMock(return_value=mock_instance) + mock_type.__name__ = "mock type" # this is needed for log statements + worker_op_kwargs["callback"] = mocker.MagicMock() + + worker_op = op.spawn_worker_op(mock_type, **worker_op_kwargs) + + assert worker_op is mock_instance + assert mock_type.call_count == 1 + + # The callback used for instantiating the worker operation is NOT the callback provided in **kwargs + assert mock_type.call_args[1]["callback"] is not worker_op_kwargs["callback"] + + # The callback provided in **kwargs is applied after instantiation + assert mock_instance.add_callback.call_count == 1 + assert mock_instance.add_callback.call_args == mocker.call(worker_op_kwargs["callback"]) + + @pytest.mark.it( + "Raises TypeError if the provided **kwargs parameters do not match the constructor for the class provided in the 'worker_op_type' parameter" + ) + def test_incorrect_kwargs(self, mocker, op, worker_op_type, worker_op_kwargs): + worker_op_kwargs["invalid_kwarg"] = "some value" + + with pytest.raises(TypeError): + op.spawn_worker_op(worker_op_type, **worker_op_kwargs) + + @pytest.mark.it( + "Returns a worker operation, which, when completed, completes the operation that spawned it with the same error status" + ) + @pytest.mark.parametrize( + "use_error", [pytest.param(False, id="No Error"), pytest.param(True, id="With Error")] + ) + def test_worker_op_completes_original_op( + self, mocker, use_error, arbitrary_exception, op, worker_op_type, worker_op_kwargs + ): + original_op = op + + if use_error: + error = arbitrary_exception + else: + error = None + + worker_op = original_op.spawn_worker_op(worker_op_type, **worker_op_kwargs) + assert not original_op.completed + + worker_op.complete(error=error) + + # Worker op has been completed with the given error state + assert worker_op.completed + assert worker_op.error is error + + # Original op is now completed with the same given error state + assert original_op.completed + assert original_op.error is error + + @pytest.mark.it( + "Returns a worker operation, which, when completed, triggers the 'callback' optionally provided in the **kwargs parameter, prior to completing the operation that spawned it" + ) + @pytest.mark.parametrize( + "use_error", [pytest.param(False, id="No Error"), pytest.param(True, id="With Error")] + ) + def test_worker_op_triggers_own_callback_and_then_completes_original_op( + self, mocker, use_error, arbitrary_exception, op, worker_op_type, worker_op_kwargs + ): + mocker.spy(handle_exceptions, "handle_background_exception") + + original_op = op + + def callback(op, error): + # Assert this callback is called before the original op begins the completion process + assert not original_op.completed + assert original_op.complete.call_count == 0 + + cb_mock = mocker.MagicMock(side_effect=callback) + + worker_op_kwargs["callback"] = cb_mock + + if use_error: + error = arbitrary_exception + else: + error = None + + worker_op = original_op.spawn_worker_op(worker_op_type, **worker_op_kwargs) + assert original_op.complete.call_count == 0 + + worker_op.complete(error=error) + + # Provided callback was called + assert cb_mock.call_count == 1 + assert cb_mock.call_args == mocker.call(op=worker_op, error=error) + + # Worker op was completed + assert worker_op.completed + + # The original op that spawned the worker is also completed + assert original_op.completed + assert original_op.complete.call_count == 1 + assert original_op.complete.call_args == mocker.call(error=error) + + # Because exceptions raised in callbacks are caught and sent to the background exception handler, + # the assertions in the above callback won't be able to directly raise AssertionErrors that will + # allow for testing normally. Instead we should check the background_exception_handler to see if + # any of the assertions raised errors and sent them there. + assert handle_exceptions.handle_background_exception.call_count == 0 + + @pytest.mark.describe("{} - .complete()".format(op_class_under_test.__name__)) + class OperationCompleteTests(OperationTestConfigClass): + @pytest.fixture(params=["Successful completion", "Completion with error"]) + def error(self, request, arbitrary_exception): + if request.param == "Successful completion": + return None + else: + return arbitrary_exception + + @pytest.mark.it( + "Triggers and removes callbacks from the operation's callback stack according to LIFO order, passing the operation and any error to each callback" + ) + def test_trigger_callbacks(self, mocker, cls_type, init_kwargs, error): + mocker.spy(handle_exceptions, "handle_background_exception") + + # Set up callback mocks + cb1_mock = mocker.MagicMock() + cb2_mock = mocker.MagicMock() + cb3_mock = mocker.MagicMock() + + def cb1(op, error): + # All callbacks have been triggered + assert cb1_mock.call_count == 1 + assert cb2_mock.call_count == 1 + assert cb3_mock.call_count == 1 + assert len(op.callback_stack) == 0 + + def cb2(op, error): + # Callback 3 and Callback 2 have been triggered, but Callback 1 has not + assert cb1_mock.call_count == 0 + assert cb2_mock.call_count == 1 + assert cb3_mock.call_count == 1 + assert len(op.callback_stack) == 1 + + def cb3(op, error): + # Callback 3 has been triggered, but no others have been. + assert cb1_mock.call_count == 0 + assert cb2_mock.call_count == 0 + assert cb3_mock.call_count == 1 + assert len(op.callback_stack) == 2 + + cb1_mock.side_effect = cb1 + cb2_mock.side_effect = cb2 + cb3_mock.side_effect = cb3 + + # Attach callbacks to op + init_kwargs["callback"] = cb1_mock + op = cls_type(**init_kwargs) + op.add_callback(cb2_mock) + op.add_callback(cb3_mock) + assert len(op.callback_stack) == 3 + assert not op.completed + + # Run the completion + op.complete(error=error) + + assert op.completed + assert cb3_mock.call_count == 1 + assert cb3_mock.call_args == mocker.call(op=op, error=error) + assert cb2_mock.call_count == 1 + assert cb2_mock.call_args == mocker.call(op=op, error=error) + assert cb1_mock.call_count == 1 + assert cb1_mock.call_args == mocker.call(op=op, error=error) + + # Because exceptions raised in callbacks are caught and sent to the background exception handler, + # the assertions in the above callbacks won't be able to directly raise AssertionErrors that will + # allow for testing normally. Instead we should check the background_exception_handler to see if + # any of the assertions raised errors and sent them there. + assert handle_exceptions.handle_background_exception.call_count == 0 + + @pytest.mark.it( + "Sets the 'error' attribute to the specified error (if any) at the beginning of the completion process" + ) + def test_sets_error(self, mocker, op, error): + mocker.spy(handle_exceptions, "handle_background_exception") + original_err = error + + def cb(op, error): + # During the completion process, the 'error' attribute has been set + assert op.error is original_err + assert error is original_err + + cb_mock = mocker.MagicMock(side_effect=cb) + op.add_callback(cb_mock) + + op.complete(error=error) + + # Callback was triggered during completion + assert cb_mock.call_count == 1 + + # After the completion process, the 'error' attribute is still set + assert op.error is error + + # Because exceptions raised in callbacks are caught and sent to the background exception handler, + # the assertion in the above callback won't be able to directly raise AssertionErrors that will + # allow for testing normally. Instead we should check the background_exception_handler to see if + # any of the assertions raised errors and sent them there. + assert handle_exceptions.handle_background_exception.call_count == 0 + + @pytest.mark.it( + "Sets the 'completing' attribute to True only for the duration of the completion process" + ) + def test_completing_set(self, mocker, op, error): + mocker.spy(handle_exceptions, "handle_background_exception") + + def cb(op, error): + # The operation is completing, but not completed + assert op.completing + assert not op.completed + + cb_mock = mocker.MagicMock(side_effect=cb) + op.add_callback(cb_mock) + + op.complete(error) + + # Callback was called + assert cb_mock.call_count == 1 + + # Once completed, the op is no longer completing + assert not op.completing + assert op.completed + + # Because exceptions raised in callbacks are caught and sent to the background exception handler, + # the assertion in the above callback won't be able to directly raise AssertionErrors that will + # allow for testing normally. Instead we should check the background_exception_handler to see if + # any of the assertions raised errors and sent them there. + assert handle_exceptions.handle_background_exception.call_count == 0 + + @pytest.mark.it( + "Handles any Exceptions raised during execution of a callback by sending them to the background exception handler, and continuing on with completion" + ) + def test_callback_raises_error( + self, mocker, arbitrary_exception, cls_type, init_kwargs, error + ): + mocker.spy(handle_exceptions, "handle_background_exception") + + # Set up callback mocks + cb1_mock = mocker.MagicMock() + cb2_mock = mocker.MagicMock(side_effect=arbitrary_exception) + cb3_mock = mocker.MagicMock() + + # Attach callbacks to op + init_kwargs["callback"] = cb1_mock + op = cls_type(**init_kwargs) + op.add_callback(cb2_mock) + op.add_callback(cb3_mock) + assert len(op.callback_stack) == 3 + assert not op.completed + + # Run the completion + op.complete(error=error) + + # Op was completed, and all callbacks triggered despite the callback raising an exception + assert op.completed + assert cb3_mock.call_count == 1 + assert cb2_mock.call_count == 1 + assert cb1_mock.call_count == 1 + assert len(op.callback_stack) == 0 + + # The exception raised by the callback was passed to the background exception handler + assert handle_exceptions.handle_background_exception.call_count == 1 + assert handle_exceptions.handle_background_exception.call_args == mocker.call( + arbitrary_exception + ) + + @pytest.mark.it( + "Allows any BaseExceptions raised during execution of a callback to propagate" + ) + def test_callback_raises_base_exception( + self, mocker, arbitrary_base_exception, cls_type, init_kwargs, error + ): + # Set up callback mocks + cb1_mock = mocker.MagicMock() + cb2_mock = mocker.MagicMock(side_effect=arbitrary_base_exception) + cb3_mock = mocker.MagicMock() + + # Attach callbacks to op + init_kwargs["callback"] = cb1_mock + op = cls_type(**init_kwargs) + op.add_callback(cb2_mock) + op.add_callback(cb3_mock) + assert len(op.callback_stack) == 3 + + # BaseException from callback is raised + with pytest.raises(arbitrary_base_exception.__class__) as e_info: + op.complete(error=error) + assert e_info.value is arbitrary_base_exception + + # Due to the BaseException raised during CB2 propagating, CB1 is never triggered + assert cb3_mock.call_count == 1 + assert cb2_mock.call_count == 1 + assert cb1_mock.call_count == 0 + + @pytest.mark.it( + "Halts triggering of callbacks if a callback invokes the .halt_completion() method, leaving untriggered callbacks in the operation's callback stack" + ) + def test_halt_during_callback(self, mocker, cls_type, init_kwargs, error): + def cb2(op, error): + # Halt the operation completion as part of the callback + op.halt_completion() + + # Set up callback mocks + cb1_mock = mocker.MagicMock() + cb2_mock = mocker.MagicMock(side_effect=cb2) + cb3_mock = mocker.MagicMock() + + # Attach callbacks to op + init_kwargs["callback"] = cb1_mock + op = cls_type(**init_kwargs) + op.add_callback(cb2_mock) + op.add_callback(cb3_mock) + assert not op.completed + assert len(op.callback_stack) == 3 + + op.complete(error=error) + + # Callback was NOT completed + assert not op.completed + + # Callback resolution was halted after CB2 due to the operation completion being halted + assert cb3_mock.call_count == 1 + assert cb2_mock.call_count == 1 + assert cb1_mock.call_count == 0 + + assert len(op.callback_stack) == 1 + assert op.callback_stack[0] is cb1_mock + + @pytest.mark.it( + "Marks the operation as fully completed by setting the 'completed' attribute to True, only once all callbacks have been triggered" + ) + def test_marks_complete(self, mocker, op, error): + mocker.spy(handle_exceptions, "handle_background_exception") + + # Set up callback mocks + cb1_mock = mocker.MagicMock() + cb2_mock = mocker.MagicMock() + + def cb(op, error): + assert not op.completed + + cb1_mock.side_effect = cb + cb2_mock.side_effect = cb + + op.add_callback(cb1_mock) + op.add_callback(cb2_mock) + + op.complete(error=error) + assert op.completed + + # Callbacks were called + assert cb1_mock.call_count == 1 + assert cb2_mock.call_count == 1 + + # Because exceptions raised in callbacks are caught and sent to the background exception handler, + # the assertion in the above callbacks won't be able to directly raise AssertionErrors that will + # allow for testing normally. Instead we should check the background_exception_handler to see if + # any of the assertions raised errors and sent them there. + assert handle_exceptions.handle_background_exception.call_count == 0 + + @pytest.mark.it( + "Sends an OperationError to the background exception handler, without making any changes to the operation, if the operation has already been completed" + ) + def test_already_complete(self, mocker, op, error): + mocker.spy(handle_exceptions, "handle_background_exception") + + # Complete the operation + op.complete(error=error) + assert op.completed + assert handle_exceptions.handle_background_exception.call_count == 0 + + # Get the operation state + original_op_err_state = op.error + origianl_op_completion_state = op.completed + + # Attempt to complete the op again + op.complete(error=error) + + # Results in failure + assert handle_exceptions.handle_background_exception.call_count == 1 + assert ( + type(handle_exceptions.handle_background_exception.call_args[0][0]) + is pipeline_exceptions.OperationError + ) + + # The operation state is unchanged + assert op.error is original_op_err_state + assert op.completed is origianl_op_completion_state + + @pytest.mark.it( + "Sends an OperationError to the background exception handler, without making any changes to the operation, if the operation is already in the process of completing" + ) + def test_already_completing(self, mocker, op, error): + mocker.spy(handle_exceptions, "handle_background_exception") + + def cb(op, error): + # Get the operation state + origianl_op_err_state = op.error + original_op_completion_state = op.completed + + # Attempt to complete the operation again while it is already in the process of completing + op.complete(error=error) + + # The operation state is unchanged + assert op.error is origianl_op_err_state + assert op.completed is original_op_completion_state + + cb_mock = mocker.MagicMock(side_effect=cb) + + op.add_callback(cb_mock) + op.complete(error=error) + + # Using the above callback resulted in failure + assert cb_mock.call_count == 1 + assert handle_exceptions.handle_background_exception.call_count == 1 + assert ( + type(handle_exceptions.handle_background_exception.call_args[0][0]) + is pipeline_exceptions.OperationError + ) + + @pytest.mark.it( + "Sends an OperationError to the background exception handler if the operation is somehow completed while still undergoing the process of completion" + ) + def test_invalid_complete_during_completion(self, mocker, op, error): + # This should never happen, as this is an invalid scenario, and could only happen due + # to a bug elsewhere in the code (e.g. manually change the boolean, as in this test) + + mocker.spy(handle_exceptions, "handle_background_exception") + + def cb(op, error): + op.completed = True + + cb_mock = mocker.MagicMock(side_effect=cb) + + op.add_callback(cb_mock) + op.complete(error=error) + + assert cb_mock.call_count == 1 + assert handle_exceptions.handle_background_exception.call_count == 1 + assert ( + type(handle_exceptions.handle_background_exception.call_args[0][0]) + is pipeline_exceptions.OperationError + ) + + @pytest.mark.it( + "Completes the operation successfully (no error) by default if no error is specified" + ) + def test_error_default(self, mocker, cls_type, init_kwargs): + cb_mock = mocker.MagicMock() + init_kwargs["callback"] = cb_mock + op = cls_type(**init_kwargs) + assert not op.completed + + op.complete() + + assert op.completed + assert op.error is None + assert cb_mock.call_count == 1 + # Callback was called passing 'None' as the error + assert cb_mock.call_args == mocker.call(op=op, error=None) + + @pytest.mark.describe("{} - .halt_completion()".format(op_class_under_test.__name__)) + class OperationHaltCompletionTests(OperationTestConfigClass): + @pytest.fixture( + params=["Currently completing with no error", "Currently completing with error"] + ) + def error(self, request, arbitrary_exception): + if request.param == "Currently completing with no error": + return None + else: + return arbitrary_exception + + @pytest.mark.it( + "Marks the operation as no longer completing by setting the 'completing' attribute to False, if the operation is currently in the process of completion" + ) + def test_sets_completing_false(self, mocker, op, error): + mocker.spy(handle_exceptions, "handle_background_exception") + + def cb(op, error): + assert op.completing + assert not op.completed + op.halt_completion() + assert not op.completing + + cb_mock = mocker.MagicMock(side_effect=cb) + op.add_callback(cb_mock) + + op.complete(error=error) + + assert not op.completing + assert not op.completed + assert cb_mock.call_count == 1 + + # Because exceptions raised in callbacks are caught and sent to the background exception handler, + # the assertion in the above callback won't be able to directly raise AssertionErrors that will + # allow for testing normally. Instead we should check the background_exception_handler to see if + # any of the assertions raised errors and sent them there. + assert handle_exceptions.handle_background_exception.call_count == 0 + + @pytest.mark.it( + "Clears the existing error in the operation's 'error' attribute, if the operation is currently in the process of completion with error" + ) + def test_clears_error(self, mocker, op, error): + mocker.spy(handle_exceptions, "handle_background_exception") + completion_error = error + + def cb(op, error): + assert op.completing + assert op.error is completion_error + op.halt_completion() + assert not op.completing + assert op.error is None + + cb_mock = mocker.MagicMock(side_effect=cb) + op.add_callback(cb_mock) + + op.complete(error=completion_error) + + assert op.error is None + assert cb_mock.call_count == 1 + + # Because exceptions raised in callbacks are caught and sent to the background exception handler, + # the assertion in the above callback won't be able to directly raise AssertionErrors that will + # allow for testing normally. Instead we should check the background_exception_handler to see if + # any of the assertions raised errors and sent them there. + assert handle_exceptions.handle_background_exception.call_count == 0 + + @pytest.mark.it( + "Sends an OperationError to the background exception handler if the operation has already been fully completed" + ) + def test_already_completed_op(self, mocker, op): + mocker.spy(handle_exceptions, "handle_background_exception") + + op.complete() + assert op.completed + op.halt_completion() + + assert handle_exceptions.handle_background_exception.call_count == 1 + assert ( + type(handle_exceptions.handle_background_exception.call_args[0][0]) + is pipeline_exceptions.OperationError + ) + + @pytest.mark.it( + "Sends an OperationError to the background exception handler if the operation has never been completed" + ) + def test_never_completed_op(self, mocker, op): + mocker.spy(handle_exceptions, "handle_background_exception") + + op.halt_completion() + + assert handle_exceptions.handle_background_exception.call_count == 1 + assert ( + type(handle_exceptions.handle_background_exception.call_args[0][0]) + is pipeline_exceptions.OperationError + ) + + setattr( + test_module, + "Test{}Instantiation".format(op_class_under_test.__name__), + OperationInstantiationTests, + ) + setattr( + test_module, "Test{}Complete".format(op_class_under_test.__name__), OperationCompleteTests + ) + setattr( + test_module, + "Test{}AddCallback".format(op_class_under_test.__name__), + OperationAddCallbackTests, + ) + setattr( + test_module, + "Test{}HaltCompletion".format(op_class_under_test.__name__), + OperationHaltCompletionTests, + ) + setattr( + test_module, + "Test{}SpawnWorkerOp".format(op_class_under_test.__name__), + OperationSpawnWorkerOpTests, + ) diff --git a/azure-iot-device/tests/common/pipeline/pipeline_stage_test.py b/azure-iot-device/tests/common/pipeline/pipeline_stage_test.py index d0dfca85c..7e8fa05ab 100644 --- a/azure-iot-device/tests/common/pipeline/pipeline_stage_test.py +++ b/azure-iot-device/tests/common/pipeline/pipeline_stage_test.py @@ -5,60 +5,232 @@ # -------------------------------------------------------------------------- import logging import pytest -import inspect -import threading -import concurrent.futures from tests.common.pipeline.helpers import ( all_except, - make_mock_stage, make_mock_op_or_event, - assert_callback_failed, - UnhandledException, - get_arg_count, - add_mock_method_waiter, + StageRunOpTestBase, + StageHandlePipelineEventTestBase, ) -from azure.iot.device.common.pipeline.pipeline_stages_base import PipelineStage -from tests.common.pipeline.pipeline_data_object_test import add_instantiation_test -from azure.iot.device.common.pipeline import pipeline_thread +from azure.iot.device.common.pipeline.pipeline_stages_base import PipelineStage, PipelineRootStage +from azure.iot.device.common.pipeline import pipeline_exceptions +from azure.iot.device.common import handle_exceptions logging.basicConfig(level=logging.DEBUG) def add_base_pipeline_stage_tests( + test_module, + stage_class_under_test, + stage_test_config_class, + extended_stage_instantiation_test_class=None, +): + class StageTestConfig(stage_test_config_class): + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + stage.next = mocker.MagicMock() + stage.previous = mocker.MagicMock() + mocker.spy(stage, "send_op_down") + mocker.spy(stage, "send_event_up") + return stage + + ####################### + # INSTANTIATION TESTS # + ####################### + + @pytest.mark.describe("{} -- Instantiation".format(stage_class_under_test.__name__)) + class StageBaseInstantiationTests(StageTestConfig): + @pytest.mark.it("Initializes 'name' attribute as the classname") + def test_name(self, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + assert stage.name == stage.__class__.__name__ + + @pytest.mark.it("Initializes 'next' attribute as None") + def test_next(self, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + assert stage.next is None + + @pytest.mark.it("Initializes 'previous' attribute as None") + def test_previous(self, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + assert stage.previous is None + + @pytest.mark.it("Initializes 'pipeline_root' attribute as None") + def test_pipeline_root(self, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + assert stage.pipeline_root is None + + if extended_stage_instantiation_test_class: + + class StageInstantiationTests( + extended_stage_instantiation_test_class, StageBaseInstantiationTests + ): + pass + + else: + + class StageInstantiationTests(StageBaseInstantiationTests): + pass + + setattr( + test_module, + "Test{}Instantiation".format(stage_class_under_test.__name__), + StageInstantiationTests, + ) + + ############## + # FLOW TESTS # + ############## + + @pytest.mark.describe("{} - .send_op_down()".format(stage_class_under_test.__name__)) + class StageSendOpDownTests(StageTestConfig): + @pytest.mark.it("Completes the op with failure (PipelineError) if there is no next stage") + def test_fails_op_when_no_next_stage(self, mocker, stage, arbitrary_op): + stage.next = None + + assert not arbitrary_op.completed + + stage.send_op_down(arbitrary_op) + + assert arbitrary_op.completed + assert type(arbitrary_op.error) is pipeline_exceptions.PipelineError + + @pytest.mark.it("Passes the op to the next stage's .run_op() method") + def test_passes_op_to_next_stage(self, mocker, stage, arbitrary_op): + stage.send_op_down(arbitrary_op) + assert stage.next.run_op.call_count == 1 + assert stage.next.run_op.call_args == mocker.call(arbitrary_op) + + @pytest.mark.describe("{} - .send_event_up()".format(stage_class_under_test.__name__)) + class StageSendEventUpTests(StageTestConfig): + @pytest.mark.it( + "Passes the event up to the previous stage's .handle_pipeline_event() method" + ) + def test_calls_handle_pipeline_event(self, stage, arbitrary_event, mocker): + stage.send_event_up(arbitrary_event) + assert stage.previous.handle_pipeline_event.call_count == 1 + assert stage.previous.handle_pipeline_event.call_args == mocker.call(arbitrary_event) + + @pytest.mark.it( + "Sends a PipelineError to the background exception handler instead of sending the event up the pipeline, if there is no previous pipeline stage" + ) + def test_no_previous_stage(self, stage, arbitrary_event, mocker): + stage.previous = None + mocker.spy(handle_exceptions, "handle_background_exception") + + stage.send_event_up(arbitrary_event) + + assert handle_exceptions.handle_background_exception.call_count == 1 + assert ( + type(handle_exceptions.handle_background_exception.call_args[0][0]) + == pipeline_exceptions.PipelineError + ) + + setattr( + test_module, + "Test{}SendOpDown".format(stage_class_under_test.__name__), + StageSendOpDownTests, + ) + setattr( + test_module, + "Test{}SendEventUp".format(stage_class_under_test.__name__), + StageSendEventUpTests, + ) + + ############################################# + # RUN OP / HANDLE_PIPELINE_EVENT BASE TESTS # + ############################################# + + # These tests are only run if the Stage in question has NOT overridden the PipelineStage base + # implementations of ._run_op() and/or ._handle_pipeline_event() + + if stage_class_under_test._run_op is PipelineStage._run_op: + + @pytest.mark.describe("{} - .run_op()".format(stage_class_under_test.__name__)) + class StageRunOpUnhandledOp(StageTestConfig, StageRunOpTestBase): + @pytest.fixture + def op(self, arbitrary_op): + return arbitrary_op + + @pytest.mark.it("Sends the operation down the pipeline") + def test_passes_down(self, mocker, stage, op): + stage.run_op(op) + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) + + setattr( + test_module, + "Test{}RunOpUnhandledOp".format(stage_class_under_test.__name__), + StageRunOpUnhandledOp, + ) + + if stage_class_under_test._handle_pipeline_event is PipelineStage._handle_pipeline_event: + + @pytest.mark.describe( + "{} - .handle_pipeline_event()".format(stage_class_under_test.__name__) + ) + class StageHandlePipelineEventUnhandledEvent( + StageTestConfig, StageHandlePipelineEventTestBase + ): + @pytest.fixture + def event(self, arbitrary_event): + return arbitrary_event + + @pytest.mark.it("Sends the event up the pipeline") + def test_passes_up(self, mocker, stage, event): + stage.handle_pipeline_event(event) + assert stage.send_event_up.call_count == 1 + assert stage.send_event_up.call_args == mocker.call(event) + + setattr( + test_module, + "Test{}HandlePipelineEventUnhandledEvent".format(stage_class_under_test.__name__), + StageHandlePipelineEventUnhandledEvent, + ) + + +############################################################# +# CODE BELOW THIS POINT IS DEPRECATED PENDING TEST OVERHAUL # +############################################################# + +# CT-TODO: Remove this as soon as possible + + +def add_base_pipeline_stage_tests_old( cls, module, all_ops, handled_ops, all_events, handled_events, - methods_that_enter_pipeline_thread=[], - methods_that_can_run_in_any_thread=[], extra_initializer_defaults={}, + positional_arguments=[], + keyword_arguments={}, ): """ Add all of the "basic" tests for validating a pipeline stage. This includes tests for instantiation and tests for properly handling "unhandled" operations and events". """ - add_instantiation_test( - cls=cls, - module=module, - defaults={"name": cls.__name__, "next": None, "previous": None, "pipeline_root": None}, - extra_defaults=extra_initializer_defaults, - ) - add_unknown_ops_tests(cls=cls, module=module, all_ops=all_ops, handled_ops=handled_ops) - add_unknown_events_tests( + # NOTE: this infrastructure has been disabled, resulting in a reduction in test coverage. + # Please port all stage tests to the new version of this function above to remedy + # this problem. + + # add_instantiation_test( + # cls=cls, + # module=module, + # defaults={"name": cls.__name__, "next": None, "previous": None, "pipeline_root": None}, + # extra_defaults=extra_initializer_defaults, + # positional_arguments=positional_arguments, + # keyword_arguments=keyword_arguments, + # ) + _add_unknown_ops_tests(cls=cls, module=module, all_ops=all_ops, handled_ops=handled_ops) + _add_unknown_events_tests( cls=cls, module=module, all_events=all_events, handled_events=handled_events ) - add_pipeline_thread_tests( - cls=cls, - module=module, - methods_that_enter_pipeline_thread=methods_that_enter_pipeline_thread, - methods_that_can_run_in_any_thread=methods_that_can_run_in_any_thread, - ) -def add_unknown_ops_tests(cls, module, all_ops, handled_ops): +def _add_unknown_ops_tests(cls, module, all_ops, handled_ops): """ Add tests for properly handling of "unknown operations," which are operations that aren't handled by a particular stage. These operations should be passed down by any stage into @@ -67,55 +239,31 @@ def add_unknown_ops_tests(cls, module, all_ops, handled_ops): unknown_ops = all_except(all_items=all_ops, items_to_exclude=handled_ops) @pytest.mark.describe("{} - .run_op() -- unknown and unhandled operations".format(cls.__name__)) - class LocalTestObject(object): - @pytest.fixture - def op(self, op_cls, callback): - op = make_mock_op_or_event(op_cls) - op.callback = callback - op.action = "pend" - add_mock_method_waiter(op, "callback") + class LocalTestObject(StageRunOpTestBase): + @pytest.fixture(params=unknown_ops) + def op(self, request, mocker): + op = make_mock_op_or_event(request.param) + op.callback_stack.append(mocker.MagicMock()) return op @pytest.fixture - def stage(self, mocker): - return make_mock_stage(mocker=mocker, stage_to_make=cls) + def stage(self): + if cls == PipelineRootStage: + return cls(None) + else: + return cls() - @pytest.mark.it("Passes unknown operation to next stage") - @pytest.mark.parametrize("op_cls", unknown_ops) - def test_passes_op_to_next_stage(self, op_cls, op, stage): + @pytest.mark.it("Passes unknown operation down to the next stage") + def test_passes_op_to_next_stage(self, mocker, op, stage): + mocker.spy(stage, "send_op_down") stage.run_op(op) - assert stage.next.run_op.call_count == 1 - assert stage.next.run_op.call_args[0][0] == op - - @pytest.mark.it("Fails unknown operation if there is no next stage") - @pytest.mark.parametrize("op_cls", unknown_ops) - def test_passes_op_with_no_next_stage(self, op_cls, op, stage): - stage.next = None - stage.run_op(op) - op.wait_for_callback_to_be_called() - assert_callback_failed(op=op) - - @pytest.mark.it("Catches Exceptions raised when passing unknown operation to next stage") - @pytest.mark.parametrize("op_cls", unknown_ops) - def test_passes_op_to_next_stage_which_throws_exception(self, op_cls, op, stage): - op.action = "exception" - stage.run_op(op) - op.wait_for_callback_to_be_called() - assert_callback_failed(op=op) - - @pytest.mark.it( - "Allows BaseExceptions raised when passing unknown operation to next start to propogate" - ) - @pytest.mark.parametrize("op_cls", unknown_ops) - def test_passes_op_to_next_stage_which_throws_base_exception(self, op_cls, op, stage): - op.action = "base_exception" - with pytest.raises(UnhandledException): - stage.run_op(op) + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) setattr(module, "Test{}UnknownOps".format(cls.__name__), LocalTestObject) -def add_unknown_events_tests(cls, module, all_events, handled_events): +def _add_unknown_events_tests(cls, module, all_events, handled_events): """ Add tests for properly handling of "unknown events," which are events that aren't handled by a particular stage. These operations should be passed up by any stage into @@ -130,140 +278,21 @@ def add_unknown_events_tests(cls, module, all_events, handled_events): @pytest.mark.describe( "{} - .handle_pipeline_event() -- unknown and unhandled events".format(cls.__name__) ) - @pytest.mark.parametrize("event_cls", unknown_events) - class LocalTestObject(object): - @pytest.fixture - def event(self, event_cls): - return make_mock_op_or_event(event_cls) + class LocalTestObject(StageHandlePipelineEventTestBase): + @pytest.fixture(params=unknown_events) + def event(self, request): + return make_mock_op_or_event(request.param) - @pytest.fixture - def stage(self, mocker): - return make_mock_stage(mocker=mocker, stage_to_make=cls) - - @pytest.fixture - def previous(self, stage, mocker): - class PreviousStage(PipelineStage): - def __init__(self): - super(PreviousStage, self).__init__() - self.handle_pipeline_event = mocker.MagicMock() - - def _execute_op(self, op): - pass - - previous = PreviousStage() - stage.previous = previous - return previous - - @pytest.mark.it("Passes unknown event to previous stage") - def test_passes_event_to_previous_stage(self, event_cls, stage, event, previous): - stage.handle_pipeline_event(event) - assert previous.handle_pipeline_event.call_count == 1 - assert previous.handle_pipeline_event.call_args[0][0] == event - - @pytest.mark.it("Calls unhandled exception handler if there is no previous stage") - def test_passes_event_with_no_previous_stage( - self, event_cls, stage, event, unhandled_error_handler - ): - stage.handle_pipeline_event(event) - assert unhandled_error_handler.call_count == 1 - - @pytest.mark.it("Catches Exceptions raised when passing unknown event to previous stage") - def test_passes_event_to_previous_stage_which_throws_exception( - self, event_cls, stage, event, previous, unhandled_error_handler - ): - e = Exception() - previous.handle_pipeline_event.side_effect = e - stage.handle_pipeline_event(event) - assert unhandled_error_handler.call_count == 1 - assert unhandled_error_handler.call_args[0][0] == e - - @pytest.mark.it( - "Allows BaseExceptions raised when passing unknown operation to next start to propogate" - ) - def test_passes_event_to_previous_stage_which_throws_base_exception( - self, event_cls, stage, event, previous, unhandled_error_handler - ): - e = UnhandledException() - previous.handle_pipeline_event.side_effect = e - with pytest.raises(UnhandledException): - stage.handle_pipeline_event(event) - assert unhandled_error_handler.call_count == 0 - - setattr(module, "Test{}UnknownEvents".format(cls.__name__), LocalTestObject) - - -class ThreadLaunchedError(Exception): - pass - - -def add_pipeline_thread_tests( - cls, module, methods_that_enter_pipeline_thread, methods_that_can_run_in_any_thread -): - def does_method_assert_pipeline_thread(method_name): - if method_name.startswith("__"): - return False - elif method_name in methods_that_enter_pipeline_thread: - return False - elif method_name in methods_that_can_run_in_any_thread: - return False - else: - return True - - methods_that_assert_pipeline_thread = [ - x[0] - for x in inspect.getmembers(cls, inspect.isfunction) - if does_method_assert_pipeline_thread(x[0]) - ] - - @pytest.mark.describe("{} - Pipeline threading".format(cls.__name__)) - class LocalTestObject(object): @pytest.fixture def stage(self): return cls() - @pytest.mark.parametrize("method_name", methods_that_assert_pipeline_thread) - @pytest.mark.it("Enforces use of the pipeline thread when calling method") - def test_asserts_in_pipeline(self, stage, method_name, fake_non_pipeline_thread): - func = getattr(stage, method_name) - args = [None for i in (range(get_arg_count(func) - 1))] - with pytest.raises(AssertionError): - func(*args) + @pytest.mark.it("Passes unknown event to previous stage") + def test_passes_event_to_previous_stage(self, stage, event, mocker): + mocker.spy(stage, "send_event_up") + stage.handle_pipeline_event(event) - if methods_that_enter_pipeline_thread: + assert stage.send_event_up.call_count == 1 + assert stage.send_event_up.call_args == mocker.call(event) - @pytest.mark.parametrize("method_name", methods_that_enter_pipeline_thread) - @pytest.mark.it("Automatically enters the pipeline thread when calling method") - def test_enters_pipeline(self, mocker, stage, method_name, fake_non_pipeline_thread): - func = getattr(stage, method_name) - args = [None for i in (range(get_arg_count(func) - 1))] - - # - # We take a bit of a roundabout way to verify that the functuion enters the - # pipeline executor: - # - # 1. we verify that the method got the pipeline executor - # 2. we verify that the method invoked _something_ on the pipeline executor - # - # It's not perfect, but it's good enough. - # - # We do this because: - # 1. We don't have the exact right args to run the method and we don't want - # to add the complexity to get the right args in this test. - # 2. We can't replace the wrapped method with a mock, AFAIK. - # - pipeline_executor = pipeline_thread._get_named_executor("pipeline") - mocker.patch.object(pipeline_executor, "submit") - pipeline_executor.submit.side_effect = ThreadLaunchedError - mocker.spy(pipeline_thread, "_get_named_executor") - - # If the method calls submit on some executor, it will raise a ThreadLaunchedError - with pytest.raises(ThreadLaunchedError): - func(*args) - - # now verify that the code got the pipeline executor and verify that it used that - # executor to launch something. - assert pipeline_thread._get_named_executor.call_count == 1 - assert pipeline_thread._get_named_executor.call_args[0][0] == "pipeline" - assert pipeline_executor.submit.call_count == 1 - - setattr(module, "Test{}PipelineThreading".format(cls.__name__), LocalTestObject) + setattr(module, "Test{}UnknownEvents".format(cls.__name__), LocalTestObject) diff --git a/azure-iot-device/tests/common/pipeline/test_operation_flow.py b/azure-iot-device/tests/common/pipeline/test_operation_flow.py deleted file mode 100644 index 6cf035df6..000000000 --- a/azure-iot-device/tests/common/pipeline/test_operation_flow.py +++ /dev/null @@ -1,133 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import logging -import pytest -from azure.iot.device.common.pipeline import ( - pipeline_thread, - pipeline_stages_base, - pipeline_ops_base, - pipeline_events_base, -) -from azure.iot.device.common.pipeline.operation_flow import ( - delegate_to_different_op, - complete_op, - pass_op_to_next_stage, -) -from tests.common.pipeline.helpers import ( - make_mock_stage, - assert_callback_failed, - assert_callback_succeeded, - UnhandledException, -) - -logging.basicConfig(level=logging.DEBUG) - - -# This fixture makes it look like all test in this file tests are running -# inside the pipeline thread. Because this is an autouse fixture, we -# manually add it to the individual test.py files that need it. If, -# instead, we had added it to some conftest.py, it would be applied to -# every tests in every file and we don't want that. -@pytest.fixture(autouse=True) -def apply_fake_pipeline_thread(fake_pipeline_thread): - pass - - -class MockPipelineStage(pipeline_stages_base.PipelineStage): - def _execute_op(self, op): - pass_op_to_next_stage(self, op) - - -@pytest.fixture -def stage(mocker): - return make_mock_stage(mocker, MockPipelineStage) - - -@pytest.mark.describe("delegate_to_different_op()") -class TestContineWithDifferntOp(object): - @pytest.mark.it("Runs the new op and does not continue running the original op") - def test_runs_new_op(self, mocker, stage, op, new_op): - delegate_to_different_op(stage, original_op=op, new_op=new_op) - assert stage.next.run_op.call_count == 1 - assert stage.next.run_op.call_args == mocker.call(new_op) - - @pytest.mark.it("Completes the original op after the new op completes") - def test_completes_original_op_after_new_op_completes(self, stage, op, new_op, callback): - op.callback = callback - new_op.action = "pend" - - delegate_to_different_op(stage, original_op=op, new_op=new_op) - assert callback.call_count == 0 # because new_op is pending - - complete_op(stage.next, new_op) - assert_callback_succeeded(op=op) - - @pytest.mark.it("Returns the new op failure in the original op if new op fails") - def test_returns_new_op_failure_in_original_op(self, stage, op, new_op, callback): - op.callback = callback - new_op.action = "fail" - delegate_to_different_op(stage, original_op=op, new_op=new_op) - assert_callback_failed(op=op, error=new_op.error) - - -@pytest.mark.describe("pass_op_to_next_stage()") -class TestContinueOp(object): - @pytest.mark.it("Completes the op without continuing if the op has an error") - def test_completes_op_with_error(self, mocker, stage, op, fake_exception, callback): - op.error = fake_exception - op.callback = callback - pass_op_to_next_stage(stage, op) - assert_callback_failed(op=op, error=fake_exception) - assert stage.next.run_op.call_count == 0 - - @pytest.mark.it("Fails the op if there is no next stage") - def test_fails_op_when_no_next_stage(self, stage, op, callback): - op.callback = callback - stage.next = None - pass_op_to_next_stage(stage, op) - assert_callback_failed(op=op) - pass - - @pytest.mark.it("Passes the op to the next stage") - def test_passes_op_to_next_stage(self, mocker, stage, op, callback): - pass_op_to_next_stage(stage, op) - assert stage.next.run_op.call_count == 1 - assert stage.next.run_op.call_args == mocker.call(op) - - -@pytest.mark.describe("complete_op()") -class TestCompleteOp(object): - @pytest.mark.it("Calls the op callback on success") - def test_calls_callback_on_success(self, stage, op, callback): - op.callback = callback - complete_op(stage, op) - assert_callback_succeeded(op) - - @pytest.mark.it("Calls the op callback on failure") - def test_calls_callback_on_error(self, stage, op, callback, fake_exception): - op.error = fake_exception - op.callback = callback - complete_op(stage, op) - assert_callback_failed(op=op, error=fake_exception) - - @pytest.mark.it( - "Handles Exceptions raised in operation callback and passes them to the unhandled error handler" - ) - def test_op_callback_raises_exception( - self, stage, op, fake_exception, mocker, unhandled_error_handler - ): - op.callback = mocker.Mock(side_effect=fake_exception) - complete_op(stage, op) - assert op.callback.call_count == 1 - assert op.callback.call_args == mocker.call(op) - assert unhandled_error_handler.call_count == 1 - assert unhandled_error_handler.call_args == mocker.call(fake_exception) - - @pytest.mark.it("Allows any BaseExceptions raised in operation callback to propagate") - def test_op_callback_raises_base_exception(self, stage, op, fake_base_exception, mocker): - op.callback = mocker.Mock(side_effect=fake_base_exception) - with pytest.raises(UnhandledException): - complete_op(stage, op) diff --git a/azure-iot-device/tests/common/pipeline/test_pipeline_events_base.py b/azure-iot-device/tests/common/pipeline/test_pipeline_events_base.py index 354d4f83c..dabd6c4bb 100644 --- a/azure-iot-device/tests/common/pipeline/test_pipeline_events_base.py +++ b/azure-iot-device/tests/common/pipeline/test_pipeline_events_base.py @@ -7,7 +7,7 @@ import sys import pytest import logging from azure.iot.device.common.pipeline import pipeline_events_base -from tests.common.pipeline import pipeline_data_object_test +from tests.common.pipeline import pipeline_event_test logging.basicConfig(level=logging.DEBUG) this_module = sys.modules[__name__] @@ -21,8 +21,8 @@ class TestPipelineOperation(object): pipeline_events_base.PipelineEvent() -pipeline_data_object_test.add_event_test( - cls=pipeline_events_base.IotResponseEvent, +pipeline_event_test.add_event_test( + cls=pipeline_events_base.ResponseEvent, module=this_module, positional_arguments=["request_id", "status_code", "response_body"], keyword_arguments={}, diff --git a/azure-iot-device/tests/common/pipeline/test_pipeline_events_mqtt.py b/azure-iot-device/tests/common/pipeline/test_pipeline_events_mqtt.py index 47efb3f0e..bf724bb23 100644 --- a/azure-iot-device/tests/common/pipeline/test_pipeline_events_mqtt.py +++ b/azure-iot-device/tests/common/pipeline/test_pipeline_events_mqtt.py @@ -6,13 +6,13 @@ import sys import logging from azure.iot.device.common.pipeline import pipeline_events_mqtt -from tests.common.pipeline import pipeline_data_object_test +from tests.common.pipeline import pipeline_event_test logging.basicConfig(level=logging.DEBUG) this_module = sys.modules[__name__] -pipeline_data_object_test.add_event_test( +pipeline_event_test.add_event_test( cls=pipeline_events_mqtt.IncomingMQTTMessageEvent, module=this_module, positional_arguments=["topic", "payload"], diff --git a/azure-iot-device/tests/common/pipeline/test_pipeline_ops_base.py b/azure-iot-device/tests/common/pipeline/test_pipeline_ops_base.py index 76fe59846..cea45d852 100644 --- a/azure-iot-device/tests/common/pipeline/test_pipeline_ops_base.py +++ b/azure-iot-device/tests/common/pipeline/test_pipeline_ops_base.py @@ -7,71 +7,268 @@ import sys import pytest import logging from azure.iot.device.common.pipeline import pipeline_ops_base -from tests.common.pipeline import pipeline_data_object_test +from tests.common.pipeline import pipeline_ops_test this_module = sys.modules[__name__] logging.basicConfig(level=logging.DEBUG) +pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") -@pytest.mark.describe("PipelineOperation") -class TestPipelineOperation(object): - @pytest.mark.it("Can't be instantiated") - def test_instantiate(self): - with pytest.raises(TypeError): - pipeline_ops_base.PipelineOperation() +class ConnectOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_base.ConnectOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = {"callback": mocker.MagicMock()} + return kwargs -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_base.ConnectOperation, - module=this_module, - positional_arguments=[], - keyword_arguments={"callback": None}, +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_base.ConnectOperation, + op_test_config_class=ConnectOperationTestConfig, ) -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_base.DisconnectOperation, - module=this_module, - positional_arguments=[], - keyword_arguments={"callback": None}, + + +class DisconnectOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_base.DisconnectOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = {"callback": mocker.MagicMock()} + return kwargs + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_base.DisconnectOperation, + op_test_config_class=DisconnectOperationTestConfig, ) -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_base.ReconnectOperation, - module=this_module, - positional_arguments=[], - keyword_arguments={"callback": None}, + + +class ReauthorizeConnectionOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_base.ReauthorizeConnectionOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = {"callback": mocker.MagicMock()} + return kwargs + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_base.ReauthorizeConnectionOperation, + op_test_config_class=ReauthorizeConnectionOperationTestConfig, ) -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_base.EnableFeatureOperation, - module=this_module, - positional_arguments=["feature_name"], - keyword_arguments={"callback": None}, + + +class EnableFeatureOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_base.EnableFeatureOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = {"feature_name": "some_feature", "callback": mocker.MagicMock()} + return kwargs + + +class EnableFeatureInstantiationTests(EnableFeatureOperationTestConfig): + @pytest.mark.it( + "Initializes 'feature_name' attribute with the provided 'feature_name' parameter" + ) + def test_feature_name(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.feature_name == init_kwargs["feature_name"] + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_base.EnableFeatureOperation, + op_test_config_class=EnableFeatureOperationTestConfig, + extended_op_instantiation_test_class=EnableFeatureInstantiationTests, ) -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_base.DisableFeatureOperation, - module=this_module, - positional_arguments=["feature_name"], - keyword_arguments={"callback": None}, + + +class DisableFeatureOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_base.DisableFeatureOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = {"feature_name": "some_feature", "callback": mocker.MagicMock()} + return kwargs + + +class DisableFeatureInstantiationTests(DisableFeatureOperationTestConfig): + @pytest.mark.it( + "Initializes 'feature_name' attribute with the provided 'feature_name' parameter" + ) + def test_feature_name(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.feature_name == init_kwargs["feature_name"] + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_base.DisableFeatureOperation, + op_test_config_class=DisableFeatureOperationTestConfig, + extended_op_instantiation_test_class=DisableFeatureInstantiationTests, ) -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_base.UpdateSasTokenOperation, - module=this_module, - positional_arguments=["sas_token"], - keyword_arguments={"callback": None}, + + +class UpdateSasTokenOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_base.UpdateSasTokenOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = {"sas_token": "some_token", "callback": mocker.MagicMock()} + return kwargs + + +class UpdateSasTokenOperationInstantiationTests(UpdateSasTokenOperationTestConfig): + @pytest.mark.it("Initializes 'sas_token' attribute with the provided 'sas_token' parameter") + def test_sas_token(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.sas_token == init_kwargs["sas_token"] + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_base.UpdateSasTokenOperation, + op_test_config_class=UpdateSasTokenOperationTestConfig, + extended_op_instantiation_test_class=UpdateSasTokenOperationInstantiationTests, ) -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_base.SendIotRequestAndWaitForResponseOperation, - module=this_module, - positional_arguments=["request_type", "method", "resource_location", "request_body"], - keyword_arguments={"callback": None}, + + +class RequestAndResponseOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_base.RequestAndResponseOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = { + "request_type": "some_request_type", + "method": "SOME_METHOD", + "resource_location": "some/resource/location", + "request_body": "some_request_body", + "callback": mocker.MagicMock(), + } + return kwargs + + +class RequestAndResponseOperationInstantiationTests(RequestAndResponseOperationTestConfig): + @pytest.mark.it( + "Initializes 'request_type' attribute with the provided 'request_type' parameter" + ) + def test_request_type(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.request_type == init_kwargs["request_type"] + + @pytest.mark.it("Initializes 'method' attribute with the provided 'method' parameter") + def test_method_type(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.method == init_kwargs["method"] + + @pytest.mark.it( + "Initializes 'resource_location' attribute with the provided 'resource_location' parameter" + ) + def test_resource_location(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.resource_location == init_kwargs["resource_location"] + + @pytest.mark.it( + "Initializes 'request_body' attribute with the provided 'request_body' parameter" + ) + def test_request_body(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.request_body == init_kwargs["request_body"] + + @pytest.mark.it("Initializes 'status_code' attribute to None") + def test_status_code(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.status_code is None + + @pytest.mark.it("Initializes 'response_body' attribute to None") + def test_response_body(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.response_body is None + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_base.RequestAndResponseOperation, + op_test_config_class=RequestAndResponseOperationTestConfig, + extended_op_instantiation_test_class=RequestAndResponseOperationInstantiationTests, ) -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_base.SendIotRequestOperation, - module=this_module, - positional_arguments=[ - "request_type", - "method", - "resource_location", - "request_body", - "request_id", - ], - keyword_arguments={"callback": None}, + + +class RequestOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_base.RequestOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = { + "method": "SOME_METHOD", + "resource_location": "some/resource/location", + "request_type": "some_request_type", + "request_body": "some_request_body", + "request_id": "some_request_id", + "callback": mocker.MagicMock(), + } + return kwargs + + +class RequestOperationInstantiationTests(RequestOperationTestConfig): + @pytest.mark.it("Initializes the 'method' attribute with the provided 'method' parameter") + def test_method(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.method == init_kwargs["method"] + + @pytest.mark.it( + "Initializes the 'resource_location' attribute with the provided 'resource_location' parameter" + ) + def test_resource_location(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.resource_location == init_kwargs["resource_location"] + + @pytest.mark.it( + "Initializes the 'request_type' attribute with the provided 'request_type' parameter" + ) + def test_request_type(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.request_type == init_kwargs["request_type"] + + @pytest.mark.it( + "Initializes the 'request_body' attribute with the provided 'request_body' parameter" + ) + def test_request_body(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.request_body == init_kwargs["request_body"] + + @pytest.mark.it( + "Initializes the 'request_id' attribute with the provided 'request_id' parameter" + ) + def test_request_id(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.request_id == init_kwargs["request_id"] + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_base.RequestOperation, + op_test_config_class=RequestOperationTestConfig, + extended_op_instantiation_test_class=RequestOperationInstantiationTests, ) diff --git a/azure-iot-device/tests/common/pipeline/test_pipeline_ops_http.py b/azure-iot-device/tests/common/pipeline/test_pipeline_ops_http.py new file mode 100644 index 000000000..6a69ab2c0 --- /dev/null +++ b/azure-iot-device/tests/common/pipeline/test_pipeline_ops_http.py @@ -0,0 +1,157 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest +import sys +import logging +from azure.iot.device.common.pipeline import pipeline_ops_http +from tests.common.pipeline import pipeline_ops_test + +logging.basicConfig(level=logging.DEBUG) +this_module = sys.modules[__name__] +pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") + + +class SetHTTPConnectionArgsOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_http.SetHTTPConnectionArgsOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = { + "hostname": "some_hostname", + "callback": mocker.MagicMock(), + "server_verification_cert": "some_server_verification_cert", + "client_cert": "some_client_cert", + "sas_token": "some_sas_token", + } + return kwargs + + +class SetHTTPConnectionArgsOperationInstantiationTests(SetHTTPConnectionArgsOperationTestConfig): + @pytest.mark.it("Initializes 'hostname' attribute with the provided 'hostname' parameter") + def test_hostname(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.hostname == init_kwargs["hostname"] + + @pytest.mark.it( + "Initializes 'server_verification_cert' attribute with the provided 'server_verification_cert' parameter" + ) + def test_server_verification_cert(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.server_verification_cert == init_kwargs["server_verification_cert"] + + @pytest.mark.it( + "Initializes 'server_verification_cert' attribute to None if no 'server_verification_cert' parameter is provided" + ) + def test_server_verification_cert_default(self, cls_type, init_kwargs): + del init_kwargs["server_verification_cert"] + op = cls_type(**init_kwargs) + assert op.server_verification_cert is None + + @pytest.mark.it("Initializes 'client_cert' attribute with the provided 'client_cert' parameter") + def test_client_cert(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.client_cert == init_kwargs["client_cert"] + + @pytest.mark.it( + "Initializes 'client_cert' attribute to None if no 'client_cert' parameter is provided" + ) + def test_client_cert_default(self, cls_type, init_kwargs): + del init_kwargs["client_cert"] + op = cls_type(**init_kwargs) + assert op.client_cert is None + + @pytest.mark.it("Initializes 'sas_token' attribute with the provided 'sas_token' parameter") + def test_sas_token(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.sas_token == init_kwargs["sas_token"] + + @pytest.mark.it( + "Initializes 'sas_token' attribute to None if no 'sas_token' parameter is provided" + ) + def test_sas_token_default(self, cls_type, init_kwargs): + del init_kwargs["sas_token"] + op = cls_type(**init_kwargs) + assert op.sas_token is None + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_http.SetHTTPConnectionArgsOperation, + op_test_config_class=SetHTTPConnectionArgsOperationTestConfig, + extended_op_instantiation_test_class=SetHTTPConnectionArgsOperationInstantiationTests, +) + + +class HTTPRequestAndResponseOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_http.HTTPRequestAndResponseOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = { + "method": "some_topic", + "path": "some_path", + "headers": {"some_key": "some_value"}, + "body": "some_body", + "query_params": "some_query_params", + "callback": mocker.MagicMock(), + } + return kwargs + + +class HTTPRequestAndResponseOperationInstantiationTests(HTTPRequestAndResponseOperationTestConfig): + @pytest.mark.it("Initializes 'method' attribute with the provided 'method' parameter") + def test_method(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.method == init_kwargs["method"] + + @pytest.mark.it("Initializes 'path' attribute with the provided 'path' parameter") + def test_path(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.path == init_kwargs["path"] + + @pytest.mark.it("Initializes 'headers' attribute with the provided 'headers' parameter") + def test_headers(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.headers == init_kwargs["headers"] + + @pytest.mark.it("Initializes 'body' attribute with the provided 'body' parameter") + def test_body(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.body == init_kwargs["body"] + + @pytest.mark.it( + "Initializes 'query_params' attribute with the provided 'query_params' parameter" + ) + def test_query_params(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.query_params == init_kwargs["query_params"] + + @pytest.mark.it("Initializes 'status_code' attribute as None") + def test_status_code(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.status_code is None + + @pytest.mark.it("Initializes 'response_body' attribute as None") + def test_response_body(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.response_body is None + + @pytest.mark.it("Initializes 'reason' attribute as None") + def test_reason(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.reason is None + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_http.SetHTTPConnectionArgsOperation, + op_test_config_class=HTTPRequestAndResponseOperationTestConfig, + extended_op_instantiation_test_class=HTTPRequestAndResponseOperationInstantiationTests, +) diff --git a/azure-iot-device/tests/common/pipeline/test_pipeline_ops_mqtt.py b/azure-iot-device/tests/common/pipeline/test_pipeline_ops_mqtt.py index 9c18e096c..52c6f2e85 100644 --- a/azure-iot-device/tests/common/pipeline/test_pipeline_ops_mqtt.py +++ b/azure-iot-device/tests/common/pipeline/test_pipeline_ops_mqtt.py @@ -3,38 +3,215 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +import pytest import sys import logging from azure.iot.device.common.pipeline import pipeline_ops_mqtt -from tests.common.pipeline import pipeline_data_object_test +from tests.common.pipeline import pipeline_ops_test logging.basicConfig(level=logging.DEBUG) this_module = sys.modules[__name__] +pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_mqtt.SetMQTTConnectionArgsOperation, - module=this_module, - positional_arguments=["client_id", "hostname", "username"], - keyword_arguments={"ca_cert": None, "client_cert": None, "sas_token": None, "callback": None}, + +class SetMQTTConnectionArgsOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_mqtt.SetMQTTConnectionArgsOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = { + "client_id": "some_client_id", + "hostname": "some_hostname", + "username": "some_username", + "callback": mocker.MagicMock(), + "server_verification_cert": "some_server_verification_cert", + "client_cert": "some_client_cert", + "sas_token": "some_sas_token", + } + return kwargs + + +class SetMQTTConnectionArgsOperationInstantiationTests(SetMQTTConnectionArgsOperationTestConfig): + @pytest.mark.it("Initializes 'client_id' attribute with the provided 'client_id' parameter") + def test_client_id(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.client_id == init_kwargs["client_id"] + + @pytest.mark.it("Initializes 'hostname' attribute with the provided 'hostname' parameter") + def test_hostname(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.hostname == init_kwargs["hostname"] + + @pytest.mark.it("Initializes 'username' attribute with the provided 'username' parameter") + def test_username(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.username == init_kwargs["username"] + + @pytest.mark.it( + "Initializes 'server_verification_cert' attribute with the provided 'server_verification_cert' parameter" + ) + def test_server_verification_cert(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.server_verification_cert == init_kwargs["server_verification_cert"] + + @pytest.mark.it( + "Initializes 'server_verification_cert' attribute to None if no 'server_verification_cert' parameter is provided" + ) + def test_server_verification_cert_default(self, cls_type, init_kwargs): + del init_kwargs["server_verification_cert"] + op = cls_type(**init_kwargs) + assert op.server_verification_cert is None + + @pytest.mark.it("Initializes 'client_cert' attribute with the provided 'client_cert' parameter") + def test_client_cert(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.client_cert == init_kwargs["client_cert"] + + @pytest.mark.it( + "Initializes 'client_cert' attribute to None if no 'client_cert' parameter is provided" + ) + def test_client_cert_default(self, cls_type, init_kwargs): + del init_kwargs["client_cert"] + op = cls_type(**init_kwargs) + assert op.client_cert is None + + @pytest.mark.it("Initializes 'sas_token' attribute with the provided 'sas_token' parameter") + def test_sas_token(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.sas_token == init_kwargs["sas_token"] + + @pytest.mark.it( + "Initializes 'sas_token' attribute to None if no 'sas_token' parameter is provided" + ) + def test_sas_token_default(self, cls_type, init_kwargs): + del init_kwargs["sas_token"] + op = cls_type(**init_kwargs) + assert op.sas_token is None + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_mqtt.SetMQTTConnectionArgsOperation, + op_test_config_class=SetMQTTConnectionArgsOperationTestConfig, + extended_op_instantiation_test_class=SetMQTTConnectionArgsOperationInstantiationTests, ) -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_mqtt.MQTTPublishOperation, - module=this_module, - positional_arguments=["topic", "payload"], - keyword_arguments={"callback": None}, - extra_defaults={"needs_connection": True}, + + +class MQTTPublishOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_mqtt.MQTTPublishOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = {"topic": "some_topic", "payload": "some_payload", "callback": mocker.MagicMock()} + return kwargs + + +class MQTTPublishOperationInstantiationTests(MQTTPublishOperationTestConfig): + @pytest.mark.it("Initializes 'topic' attribute with the provided 'topic' parameter") + def test_topic(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.topic == init_kwargs["topic"] + + @pytest.mark.it("Initializes 'payload' attribute with the provided 'payload' parameter") + def test_payload(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.payload == init_kwargs["payload"] + + @pytest.mark.it("Initializes 'needs_connection' attribute as True") + def test_needs_connection(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.needs_connection is True + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_mqtt.MQTTPublishOperation, + op_test_config_class=MQTTPublishOperationTestConfig, + extended_op_instantiation_test_class=MQTTPublishOperationInstantiationTests, ) -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_mqtt.MQTTSubscribeOperation, - module=this_module, - positional_arguments=["topic"], - keyword_arguments={"callback": None}, - extra_defaults={"needs_connection": True}, + + +class MQTTSubscribeOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_mqtt.MQTTSubscribeOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = {"topic": "some_topic", "callback": mocker.MagicMock()} + return kwargs + + +class MQTTSubscribeOperationInstantiationTests(MQTTSubscribeOperationTestConfig): + @pytest.mark.it("Initializes 'topic' attribute with the provided 'topic' parameter") + def test_topic(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.topic == init_kwargs["topic"] + + @pytest.mark.it("Initializes 'needs_connection' attribute as True") + def test_needs_connection(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.needs_connection is True + + @pytest.mark.it("Initializes 'timeout_timer' attribute as None") + def test_timeout_timer(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.timeout_timer is None + + @pytest.mark.it("Initializes 'retry_timer' attribute as None") + def test_retry_timer(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.retry_timer is None + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_mqtt.MQTTSubscribeOperation, + op_test_config_class=MQTTSubscribeOperationTestConfig, + extended_op_instantiation_test_class=MQTTSubscribeOperationInstantiationTests, ) -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_mqtt.MQTTUnsubscribeOperation, - module=this_module, - positional_arguments=["topic"], - keyword_arguments={"callback": None}, - extra_defaults={"needs_connection": True}, + + +class MQTTUnsubscribeOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_mqtt.MQTTUnsubscribeOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = {"topic": "some_topic", "callback": mocker.MagicMock()} + return kwargs + + +class MQTTUnsubscribeOperationInstantiationTests(MQTTUnsubscribeOperationTestConfig): + @pytest.mark.it("Initializes 'topic' attribute with the provided 'topic' parameter") + def test_topic(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.topic == init_kwargs["topic"] + + @pytest.mark.it("Initializes 'needs_connection' attribute as True") + def test_needs_connection(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.needs_connection is True + + @pytest.mark.it("Initializes 'timeout_timer' attribute as None") + def test_timeout_timer(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.timeout_timer is None + + @pytest.mark.it("Initializes 'retry_timer' attribute as None") + def test_retry_timer(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.retry_timer is None + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_mqtt.MQTTUnsubscribeOperation, + op_test_config_class=MQTTUnsubscribeOperationTestConfig, + extended_op_instantiation_test_class=MQTTUnsubscribeOperationInstantiationTests, ) diff --git a/azure-iot-device/tests/common/pipeline/test_pipeline_stages_base.py b/azure-iot-device/tests/common/pipeline/test_pipeline_stages_base.py index 22b222dcb..eaa18eb8d 100644 --- a/azure-iot-device/tests/common/pipeline/test_pipeline_stages_base.py +++ b/azure-iot-device/tests/common/pipeline/test_pipeline_stages_base.py @@ -4,719 +4,2788 @@ # license information. # -------------------------------------------------------------------------- import logging +import copy +import time import pytest import sys import six import threading +import random +import uuid from six.moves import queue +from azure.iot.device.common import transport_exceptions, handle_exceptions from azure.iot.device.common.pipeline import ( pipeline_stages_base, pipeline_ops_base, pipeline_ops_mqtt, pipeline_events_base, - operation_flow, -) -from tests.common.pipeline.helpers import ( - make_mock_stage, - assert_callback_failed, - assert_callback_succeeded, - UnhandledException, - all_common_ops, - all_common_events, + pipeline_exceptions, ) +from .helpers import StageRunOpTestBase, StageHandlePipelineEventTestBase +from .fixtures import ArbitraryOperation from tests.common.pipeline import pipeline_stage_test this_module = sys.modules[__name__] logging.basicConfig(level=logging.DEBUG) +pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") -# This fixture makes it look like all test in this file tests are running -# inside the pipeline thread. Because this is an autouse fixture, we -# manually add it to the individual test.py files that need it. If, -# instead, we had added it to some conftest.py, it would be applied to -# every tests in every file and we don't want that. -@pytest.fixture(autouse=True) -def apply_fake_pipeline_thread(fake_pipeline_thread): +################### +# COMMON FIXTURES # +################### +@pytest.fixture +def mock_timer(mocker): + return mocker.patch.object(threading, "Timer") + + +# Not a fixture, but useful for sharing +def fake_callback(*args, **kwargs): pass -# Workaround for flake8. A class with this name is actually created inside -# add_base_pipeline_stage_test, but flake8 doesn't know that -class TestPipelineRootStagePipelineThreading: - pass +####################### +# PIPELINE ROOT STAGE # +####################### + + +class PipelineRootStageTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_stages_base.PipelineRootStage + + @pytest.fixture + def init_kwargs(self, mocker): + return {"pipeline_configuration": mocker.MagicMock()} + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + return stage + + +class PipelineRootStageInstantiationTests(PipelineRootStageTestConfig): + @pytest.mark.it("Initializes 'on_pipeline_event_handler' as None") + def test_on_pipeline_event_handler(self, init_kwargs): + stage = pipeline_stages_base.PipelineRootStage(**init_kwargs) + assert stage.on_pipeline_event_handler is None + + @pytest.mark.it("Initializes 'on_connected_handler' as None") + def test_on_connected_handler(self, init_kwargs): + stage = pipeline_stages_base.PipelineRootStage(**init_kwargs) + assert stage.on_connected_handler is None + + @pytest.mark.it("Initializes 'on_disconnected_handler' as None") + def test_on_disconnected_handler(self, init_kwargs): + stage = pipeline_stages_base.PipelineRootStage(**init_kwargs) + assert stage.on_disconnected_handler is None + + @pytest.mark.it("Initializes 'connected' as False") + def test_connected(self, init_kwargs): + stage = pipeline_stages_base.PipelineRootStage(**init_kwargs) + assert stage.connected is False + + @pytest.mark.it( + "Initializes 'pipeline_configuration' with the provided 'pipeline_configuration' parameter" + ) + def test_pipeline_configuration(self, init_kwargs): + stage = pipeline_stages_base.PipelineRootStage(**init_kwargs) + assert stage.pipeline_configuration is init_kwargs["pipeline_configuration"] pipeline_stage_test.add_base_pipeline_stage_tests( - cls=pipeline_stages_base.PipelineRootStage, - module=this_module, - all_ops=all_common_ops, - handled_ops=[], - all_events=all_common_events, - handled_events=all_common_events, - methods_that_can_run_in_any_thread=["append_stage", "run_op"], - extra_initializer_defaults={ - "on_pipeline_event_handler": None, - "on_connected_handler": None, - "on_disconnected_handler": None, - "connected": False, - }, + test_module=this_module, + stage_class_under_test=pipeline_stages_base.PipelineRootStage, + stage_test_config_class=PipelineRootStageTestConfig, + extended_stage_instantiation_test_class=PipelineRootStageInstantiationTests, ) -@pytest.mark.it("Calls operation callback in callback thread") -def _test_pipeline_root_runs_callback_in_callback_thread(self, stage, mocker): - # the stage fixture comes from the TestPipelineRootStagePipelineThreading object that - # this test method gets added to, so it's a PipelineRootStage object - stage.pipeline_root = stage - callback_called = threading.Event() +@pytest.mark.describe("PipelineRootStage - .append_stage()") +class TestPipelineRootStageAppendStage(PipelineRootStageTestConfig): + @pytest.mark.it("Appends the provided stage to the tail of the pipeline") + @pytest.mark.parametrize( + "pipeline_len", + [ + pytest.param(1, id="Pipeline Length: 1"), + pytest.param(2, id="Pipeline Length: 2"), + pytest.param(3, id="Pipeline Length: 3"), + pytest.param(10, id="Pipeline Length: 10"), + pytest.param(random.randint(4, 99), id="Randomly chosen Pipeline Length"), + ], + ) + def test_appends_new_stage(self, stage, pipeline_len): + class ArbitraryStage(pipeline_stages_base.PipelineStage): + pass - def callback(op): - assert threading.current_thread().name == "callback" - callback_called.set() - - op = pipeline_ops_base.ConnectOperation(callback=callback) - stage.run_op(op) - callback_called.wait() + assert stage.next is None + assert stage.previous is None + prev_tail = stage + root = stage + for i in range(0, pipeline_len): + new_stage = ArbitraryStage() + stage.append_stage(new_stage) + assert prev_tail.next is new_stage + assert new_stage.previous is prev_tail + assert new_stage.pipeline_root is root + prev_tail = new_stage -@pytest.mark.it("Runs operation in pipeline thread") -def _test_pipeline_root_runs_operation_in_pipeline_thread( - self, mocker, stage, op, fake_non_pipeline_thread +# NOTE 1: Because the Root stage overrides the parent implementation, we must test it here +# (even though it's the same test). +# NOTE 2: Currently this implementation does some other things with threads, but we do not +# currently have a thread testing strategy, so it is untested for now. +@pytest.mark.describe("PipelineRootStage - .run_op()") +class TestPipelineRootStageRunOp(PipelineRootStageTestConfig): + @pytest.fixture + def op(self, arbitrary_op): + return arbitrary_op + + @pytest.mark.it("Sends the operation down") + def test_sends_op_down(self, mocker, stage, op): + stage.run_op(op) + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) + + +@pytest.mark.describe("PipelineRootStage - .handle_pipeline_event() -- Called with ConnectedEvent") +class TestPipelineRootStageHandlePipelineEventWithConnectedEvent( + PipelineRootStageTestConfig, StageHandlePipelineEventTestBase ): - # the stage fixture comes from the TestPipelineRootStagePipelineThreading object that - # this test method gets added to, so it's a PipelineRootStage object - assert threading.current_thread().name != "pipeline" + @pytest.fixture + def event(self): + return pipeline_events_base.ConnectedEvent() - def mock_execute_op(self, op): - print("mock_execute_op called") - assert threading.current_thread().name == "pipeline" - op.callback(op) + @pytest.mark.it("Sets the 'connected' attribute to True") + def test_set_connected_true(self, stage, event): + assert not stage.connected + stage.handle_pipeline_event(event) + assert stage.connected - mock_execute_op = mocker.MagicMock(mock_execute_op) - stage._execute_op = mock_execute_op - - stage.run_op(op) - assert mock_execute_op.call_count == 1 + @pytest.mark.it("Invokes the 'on_connected_handler' handler function, if set") + def test_invoke_handler(self, mocker, stage, event): + mock_handler = mocker.MagicMock() + stage.on_connected_handler = mock_handler + stage.handle_pipeline_event(event) + time.sleep(0.1) # CT-TODO / BK-TODO: get rid of this + assert mock_handler.call_count == 1 + assert mock_handler.call_args == mocker.call() -@pytest.mark.it("Calls on_connected_handler in callback thread") -def _test_pipeline_root_runs_on_connected_in_callback_thread(self, stage, mocker): - stage.pipeline_root = stage - callback_called = threading.Event() - - def callback(*arg, **argv): - assert threading.current_thread().name == "callback" - callback_called.set() - - stage.on_connected_handler = callback - - stage.on_connected() - callback_called.wait() - - -@pytest.mark.it("Calls on_disconnected_handler in callback thread") -def _test_pipeline_root_runs_on_disconnected_in_callback_thread(self, stage, mocker): - stage.pipeline_root = stage - callback_called = threading.Event() - - def callback(*arg, **argv): - assert threading.current_thread().name == "callback" - callback_called.set() - - stage.on_disconnected_handler = callback - - stage.on_disconnected() - callback_called.wait() - - -@pytest.mark.it("Calls on_event_received_handler in callback thread") -def _test_pipeline_root_runs_on_event_received_in_callback_thread(self, stage, mocker, event): - stage.pipeline_root = stage - callback_called = threading.Event() - - def callback(*arg, **argv): - assert threading.current_thread().name == "callback" - callback_called.set() - - stage.on_pipeline_event_handler = callback - - stage.handle_pipeline_event(event) - callback_called.wait() - - -TestPipelineRootStagePipelineThreading.test_runs_callback_in_callback_thread = ( - _test_pipeline_root_runs_callback_in_callback_thread +@pytest.mark.describe( + "PipelineRootStage - .handle_pipeline_event() -- Called with DisconnectedEvent" ) -TestPipelineRootStagePipelineThreading.test_runs_operation_in_pipeline_thread = ( - _test_pipeline_root_runs_operation_in_pipeline_thread -) -TestPipelineRootStagePipelineThreading.test_pipeline_root_runs_on_connected_in_callback_thread = ( - _test_pipeline_root_runs_on_connected_in_callback_thread -) -TestPipelineRootStagePipelineThreading.test_pipeline_root_runs_on_disconnected_in_callback_thread = ( - _test_pipeline_root_runs_on_disconnected_in_callback_thread -) -TestPipelineRootStagePipelineThreading.test_pipeline_root_runs_on_event_received_in_callback_thread = ( - _test_pipeline_root_runs_on_event_received_in_callback_thread +class TestPipelineRootStageHandlePipelineEventWithDisconnectedEvent( + PipelineRootStageTestConfig, StageHandlePipelineEventTestBase +): + @pytest.fixture + def event(self): + return pipeline_events_base.DisconnectedEvent() + + @pytest.mark.it("Sets the 'connected' attribute to True") + def test_set_connected_false(self, stage, event): + stage.connected = True + stage.handle_pipeline_event(event) + assert not stage.connected + + @pytest.mark.it("Invokes the 'on_disconnected_handler' handler function, if set") + def test_invoke_handler(self, mocker, stage, event): + mock_handler = mocker.MagicMock() + stage.on_disconnected_handler = mock_handler + stage.handle_pipeline_event(event) + time.sleep(0.1) # CT-TODO / BK-TODO: get rid of this + assert mock_handler.call_count == 1 + assert mock_handler.call_args == mocker.call() + + +@pytest.mark.describe( + "PipelineRootStage - .handle_pipeline_event() -- Called with an arbitrary other event" ) +class TestPipelineRootStageHandlePipelineEventWithArbitraryEvent( + PipelineRootStageTestConfig, StageHandlePipelineEventTestBase +): + @pytest.fixture + def event(self, arbitrary_event): + return arbitrary_event + + @pytest.mark.it("Invokes the 'on_pipeline_event_handler' handler function, if set") + def test_invoke_handler(self, mocker, stage, event): + mock_handler = mocker.MagicMock() + stage.on_pipeline_event_handler = mock_handler + stage.handle_pipeline_event(event) + time.sleep(0.1) # CT-TODO/BK-TODO: get rid of this + assert mock_handler.call_count == 1 + assert mock_handler.call_args == mocker.call(event) + + +###################### +# AUTO CONNECT STAGE # +###################### + + +class AutoConnectStageTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_stages_base.AutoConnectStage + + @pytest.fixture + def init_kwargs(self, mocker): + return {} + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=mocker.MagicMock() + ) + # Mock flow methods + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + return stage + pipeline_stage_test.add_base_pipeline_stage_tests( - cls=pipeline_stages_base.EnsureConnectionStage, - module=this_module, - all_ops=all_common_ops, - handled_ops=[ + test_module=this_module, + stage_class_under_test=pipeline_stages_base.AutoConnectStage, + stage_test_config_class=AutoConnectStageTestConfig, +) + + +@pytest.mark.describe( + "AutoConnectStage - .run_op() -- Called with an Operation that requires an active connection" +) +class TestAutoConnectStageRunOpWithOpThatRequiresConnection( + AutoConnectStageTestConfig, StageRunOpTestBase +): + + fake_topic = "__fake_topic__" + fake_payload = "__fake_payload__" + + ops_requiring_connection = [ pipeline_ops_mqtt.MQTTPublishOperation, pipeline_ops_mqtt.MQTTSubscribeOperation, pipeline_ops_mqtt.MQTTUnsubscribeOperation, - ], - all_events=all_common_events, - handled_events=[], -) + ] -fake_topic = "__fake_topic__" -fake_payload = "__fake_payload__" -ops_that_cause_connection = [ - { - "op_class": pipeline_ops_mqtt.MQTTPublishOperation, - "op_init_kwargs": {"topic": fake_topic, "payload": fake_payload}, - }, - {"op_class": pipeline_ops_mqtt.MQTTSubscribeOperation, "op_init_kwargs": {"topic": fake_topic}}, - { - "op_class": pipeline_ops_mqtt.MQTTUnsubscribeOperation, - "op_init_kwargs": {"topic": fake_topic}, - }, -] - - -@pytest.mark.parametrize( - "params", - ops_that_cause_connection, - ids=[x["op_class"].__name__ for x in ops_that_cause_connection], -) -@pytest.mark.describe( - "EnsureConnectionStage - .run_op() -- called with operation that causes a connection to be established" -) -class TestEnsureConnectionStageRunOp(object): - @pytest.fixture - def op(self, mocker, params): - op = params["op_class"](**params["op_init_kwargs"]) - op.callback = mocker.MagicMock() + @pytest.fixture(params=ops_requiring_connection) + def op(self, mocker, request): + op_class = request.param + if op_class is pipeline_ops_mqtt.MQTTPublishOperation: + op = op_class( + topic=self.fake_topic, payload=self.fake_payload, callback=mocker.MagicMock() + ) + else: + op = op_class(topic=self.fake_topic, callback=mocker.MagicMock()) + assert op.needs_connection return op - @pytest.fixture - def stage(self, mocker): - stage = make_mock_stage( - mocker=mocker, stage_to_make=pipeline_stages_base.EnsureConnectionStage - ) - stage.next.run_op = mocker.MagicMock() - return stage - - @pytest.mark.it("Passes the operation down the pipline when the transport is already connected") - def test_operation_alrady_connected(self, params, op, stage): + @pytest.mark.it( + "Sends the operation down the pipeline if the pipeline is already in a 'connected' state" + ) + def test_already_connected(self, mocker, stage, op): stage.pipeline_root.connected = True stage.run_op(op) - assert stage.next.run_op.call_count == 1 - assert stage.next.run_op.call_args[0][0] == op + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) @pytest.mark.it( - "Sends a ConnectOperation instead of the op down the pipeline if the transport is not connected" + "Sends a new ConnectOperation down the pipeline if the pipeline is not yet in a 'connected' state" ) - def test_sends_connect(self, params, op, stage): - stage.pipeline_root.connected = False + def test_not_connected(self, mocker, stage, op): + mock_connect_op = mocker.patch.object(pipeline_ops_base, "ConnectOperation").return_value + assert not stage.pipeline_root.connected stage.run_op(op) - assert stage.next.run_op.call_count == 1 - assert isinstance(stage.next.run_op.call_args[0][0], pipeline_ops_base.ConnectOperation) + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(mock_connect_op) @pytest.mark.it( - "Calls the op's callback with the error from the ConnectOperation if that operation fails" + "Sends the operation down the pipeline once the ConnectOperation completes successfully" ) - def test_connect_failure(self, params, op, stage, fake_exception): - stage.pipeline_root.connected = False + def test_connect_success(self, mocker, stage, op): + assert not stage.pipeline_root.connected + + # Run the original operation + stage.run_op(op) + assert not op.completed + + # Complete the newly created ConnectOperation that was sent down the pipeline + assert stage.send_op_down.call_count == 1 + connect_op = stage.send_op_down.call_args[0][0] + assert isinstance(connect_op, pipeline_ops_base.ConnectOperation) + assert not connect_op.completed + connect_op.complete() # no error + + # The original operation has now been sent down the pipeline + assert stage.send_op_down.call_count == 2 + assert stage.send_op_down.call_args == mocker.call(op) + + @pytest.mark.it( + "Completes the operation with the error from the ConnectOperation, if the ConnectOperation completes with an error" + ) + def test_connect_failure(self, mocker, stage, op, arbitrary_exception): + assert not stage.pipeline_root.connected + + # Run the original operation + stage.run_op(op) + assert not op.completed + + # Complete the newly created ConnectOperation that was sent down the pipeline + assert stage.send_op_down.call_count == 1 + connect_op = stage.send_op_down.call_args[0][0] + assert isinstance(connect_op, pipeline_ops_base.ConnectOperation) + assert not connect_op.completed + connect_op.complete(error=arbitrary_exception) # completes with error + + # The original operation has been completed the exception from the ConnectOperation + assert op.completed + assert op.error is arbitrary_exception + + +@pytest.mark.describe( + "AutoConnectStage - .run_op() -- Called with an Operation that does not require an active connection" +) +class TestAutoConnectStageRunOpWithOpThatDoesNotRequireConnection( + AutoConnectStageTestConfig, StageRunOpTestBase +): + @pytest.fixture + def op(self, arbitrary_op): + assert not arbitrary_op.needs_connection + return arbitrary_op + + @pytest.mark.it( + "Sends the operation down the pipeline if the pipeline is in a 'connected' state" + ) + def test_connected(self, mocker, stage, op): + stage.pipeline_root.connected = True stage.run_op(op) - connect_op = stage.next.run_op.call_args[0][0] - connect_op.error = fake_exception - operation_flow.complete_op(stage=stage.next, op=connect_op) + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) - assert_callback_failed(op=op, error=fake_exception) - - @pytest.mark.it("Waits for the ConnectOperation to complete before pasing the operation down") - def test_connect_success(self, params, op, stage): - stage.pipeline_root.connected = False + @pytest.mark.it( + "Sends the operation down the pipeline if the pipeline is in a 'disconnected' state" + ) + def test_disconnected(self, mocker, stage, op): + assert not stage.pipeline_root.connected stage.run_op(op) - assert stage.next.run_op.call_count == 1 - connect_op = stage.next.run_op.call_args[0][0] - operation_flow.complete_op(stage=stage.next, op=connect_op) + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) - assert stage.next.run_op.call_count == 2 - assert stage.next.run_op.call_args[0][0] == op - @pytest.mark.it("calls the op's callback when the operation is complete after connecting") - def test_operation_complete(self, params, op, stage): - stage.pipeline_root.connected = False +######################### +# CONNECTION LOCK STAGE # +######################### - stage.run_op(op) - connect_op = stage.next.run_op.call_args[0][0] - operation_flow.complete_op(stage=stage.next, op=connect_op) +# This is a list of operations which can trigger a block on the ConnectionLockStage +connection_ops = [ + pipeline_ops_base.ConnectOperation, + pipeline_ops_base.DisconnectOperation, + pipeline_ops_base.ReauthorizeConnectionOperation, +] - operation_flow.complete_op(stage=stage.next, op=op) - assert_callback_succeeded(op=op) - @pytest.mark.it("calls the op's callback when the operation fails after connecting") - def test_operation_fails(self, params, op, stage, fake_exception): - stage.pipeline_root.connected = False +class ConnectionLockStageTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_stages_base.ConnectionLockStage - stage.run_op(op) - connect_op = stage.next.run_op.call_args[0][0] - operation_flow.complete_op(stage=stage.next, op=connect_op) - op.error = fake_exception - operation_flow.complete_op(stage=stage.next, op=op) + @pytest.fixture + def init_kwargs(self, mocker): + return {} - assert_callback_failed(op=op, error=fake_exception) + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=mocker.MagicMock() + ) + stage.send_op_down = mocker.MagicMock() + return stage + + +class ConnectionLockStageInstantiationTests(ConnectionLockStageTestConfig): + @pytest.mark.it("Initializes 'queue' as an empty Queue object") + def test_queue(self, init_kwargs): + stage = pipeline_stages_base.ConnectionLockStage(**init_kwargs) + assert isinstance(stage.queue, queue.Queue) + assert stage.queue.empty() + + @pytest.mark.it("Initializes 'blocked' as False") + def test_blocked(self, init_kwargs): + stage = pipeline_stages_base.ConnectionLockStage(**init_kwargs) + assert not stage.blocked pipeline_stage_test.add_base_pipeline_stage_tests( - cls=pipeline_stages_base.SerializeConnectOpsStage, - module=this_module, - all_ops=all_common_ops, - handled_ops=[ - pipeline_ops_base.ConnectOperation, - pipeline_ops_base.DisconnectOperation, - pipeline_ops_base.ReconnectOperation, - ], - all_events=all_common_events, - handled_events=[], - extra_initializer_defaults={"blocked": False, "queue": queue.Queue}, + test_module=this_module, + stage_class_under_test=pipeline_stages_base.ConnectionLockStage, + stage_test_config_class=ConnectionLockStageTestConfig, + extended_stage_instantiation_test_class=ConnectionLockStageInstantiationTests, ) -connection_ops = [ - {"op_class": pipeline_ops_base.ConnectOperation, "connected_flag_required_to_run": False}, - {"op_class": pipeline_ops_base.DisconnectOperation, "connected_flag_required_to_run": True}, - {"op_class": pipeline_ops_base.ReconnectOperation, "connected_flag_required_to_run": True}, -] - - -class FakeOperation(pipeline_ops_base.PipelineOperation): - pass - @pytest.mark.describe( - "SerializeConnectOpsStage - .run_op() -- called with an operation that connects, disconnects, or reconnects" + "ConnectionLockStage - .run_op() -- Called with a ConnectOperation while not in a blocking state" ) -class TestSerializeConnectOpStageRunOp(object): +class TestConnectionLockStageRunOpWithConnectOpWhileUnblocked( + ConnectionLockStageTestConfig, StageRunOpTestBase +): @pytest.fixture - def stage(self, mocker): - stage = make_mock_stage( - mocker=mocker, stage_to_make=pipeline_stages_base.SerializeConnectOpsStage + def op(self, mocker): + return pipeline_ops_base.ConnectOperation(callback=mocker.MagicMock()) + + @pytest.mark.it("Completes the operation immediately if the pipeline is already connected") + def test_already_connected(self, mocker, stage, op): + stage.pipeline_root.connected = True + + # Run the operation + stage.run_op(op) + + # Operation is completed + assert op.completed + assert op.error is None + + # Stage is still not blocked + assert not stage.blocked + + @pytest.mark.it( + "Puts the stage in a blocking state and sends the operation down the pipeline, if the pipeline is not currently connected" + ) + def test_not_connected(self, mocker, stage, op): + stage.pipeline_root.connected = False + + # Stage is not blocked + assert not stage.blocked + + # Run the operation + stage.run_op(op) + + # Stage is now blocked + assert stage.blocked + + # Operation was passed down + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) + + # Operation is not yet completed + assert not op.completed + + +@pytest.mark.describe( + "ConnectionLockStage - .run_op() -- Called with a DisconnectOperation while not in a blocking state" +) +class TestConnectionLockStageRunOpWithDisconnectOpWhileUnblocked( + ConnectionLockStageTestConfig, StageRunOpTestBase +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_base.DisconnectOperation(callback=mocker.MagicMock()) + + @pytest.mark.it("Completes the operation immediately if the pipeline is already disconnected") + def test_already_disconnected(self, mocker, stage, op): + stage.pipeline_root.connected = False + + # Run the operation + stage.run_op(op) + + # Operation is completed + assert op.completed + assert op.error is None + + # Stage is still not blocked + assert not stage.blocked + + @pytest.mark.it( + "Puts the stage in a blocking state and sends the operation down the pipeline, if the pipeline is currently connected" + ) + def test_connected(self, mocker, stage, op): + stage.pipeline_root.connected = True + + # Stage is not blocked + assert not stage.blocked + + # Run the operation + stage.run_op(op) + + # Stage is now blocked + assert stage.blocked + + # Operation was passed down + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) + + # Operation is not yet completed + assert not op.completed + + +@pytest.mark.describe( + "ConnectionLockStage - .run_op() -- Called with a ReauthorizeConnectionOperation while not in a blocking state" +) +class TestConnectionLockStageRunOpWithReconnectOpWhileUnblocked( + ConnectionLockStageTestConfig, StageRunOpTestBase +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_base.ReauthorizeConnectionOperation(callback=mocker.MagicMock()) + + @pytest.mark.it("Puts the stage in a blocking state and sends the operation down the pipeline") + @pytest.mark.parametrize( + "connected", + [ + pytest.param(True, id="Pipeline Connected"), + pytest.param(False, id="Pipeline Disconnected"), + ], + ) + def test_not_connected(self, mocker, connected, stage, op): + stage.pipeline_root.connected = connected + + # Stage is not blocked + assert not stage.blocked + + # Run the operation + stage.run_op(op) + + # Stage is now blocked + assert stage.blocked + + # Operation was passed down + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) + + # Operation is not yet completed + assert not op.completed + + +@pytest.mark.describe( + "ConnectionLockStage - .run_op() -- Called with an arbitrary other operation while not in a blocking state" +) +class TestConnectionLockStageRunOpWithArbitraryOpWhileUnblocked( + ConnectionLockStageTestConfig, StageRunOpTestBase +): + @pytest.fixture + def op(self, arbitrary_op): + return arbitrary_op + + @pytest.mark.it("Sends the operation down the pipeline") + @pytest.mark.parametrize( + "connected", + [ + pytest.param(True, id="Pipeline Connected"), + pytest.param(False, id="Pipeline Disconnected"), + ], + ) + def test_sends_down(self, mocker, connected, stage, op): + stage.pipeline_root.connected = connected + + stage.run_op(op) + + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) + + +@pytest.mark.describe("ConnectionLockStage - .run_op() -- Called while in a blocking state") +class TestConnectionLockStageRunOpWhileBlocked(ConnectionLockStageTestConfig, StageRunOpTestBase): + @pytest.fixture + def blocking_op(self, mocker): + return pipeline_ops_base.ConnectOperation(callback=mocker.MagicMock()) + + @pytest.fixture + def stage(self, mocker, init_kwargs, blocking_op): + stage = pipeline_stages_base.ConnectionLockStage(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=mocker.MagicMock() ) - stage.next.run_op = mocker.MagicMock() + stage.send_op_down = mocker.MagicMock() + mocker.spy(stage, "run_op") + assert not stage.blocked + + # Block the stage by running a blocking operation + stage.run_op(blocking_op) + assert stage.blocked + + # Reset the mock for ease of testing + stage.send_op_down.reset_mock() + stage.run_op.reset_mock() + return stage + + @pytest.fixture(params=(connection_ops + [ArbitraryOperation])) + def op(self, mocker, request): + conn_op_class = request.param + op = conn_op_class(callback=mocker.MagicMock()) + return op + + @pytest.mark.it( + "Adds the operation to the queue, pending the completion of the operation on which the stage is blocked" + ) + def test_adds_to_queue(self, mocker, stage, op): + assert stage.queue.empty() + stage.run_op(op) + + # Operation is in queue + assert not stage.queue.empty() + assert stage.queue.qsize() == 1 + assert stage.queue.get(block=False) is op + + # Operation was not passed down + assert stage.send_op_down.call_count == 0 + + # Operation has not been completed + assert not op.completed + + @pytest.mark.it( + "Adds the operation to the queue, even if the operation's desired pipeline connection state already has been reached" + ) + @pytest.mark.parametrize( + "op", + [pipeline_ops_base.ConnectOperation, pipeline_ops_base.DisconnectOperation], + indirect=True, + ) + def test_blocks_ops_ready_for_completion(self, mocker, stage, op): + # Set the pipeline connection state to be the one desired by the operation. + # If the stage were unblocked, this would lead to immediate completion of the op. + if isinstance(op, pipeline_ops_base.ConnectOperation): + stage.pipeline_root.connected = True + else: + stage.pipeline_root.connected = False + + assert stage.queue.empty() + + stage.run_op(op) + + assert not op.completed + assert stage.queue.qsize() == 1 + assert stage.send_op_down.call_count == 0 + + @pytest.mark.it( + "Can support multiple pending operations if called multiple times during the blocking state" + ) + def test_multiple_ops_added_to_queue(self, mocker, stage): + assert stage.queue.empty() + + op1 = pipeline_ops_base.DisconnectOperation(callback=mocker.MagicMock()) + op2 = pipeline_ops_base.ConnectOperation(callback=mocker.MagicMock()) + op3 = pipeline_ops_base.ReauthorizeConnectionOperation(callback=mocker.MagicMock()) + op4 = ArbitraryOperation(callback=mocker.MagicMock()) + + stage.run_op(op1) + stage.run_op(op2) + stage.run_op(op3) + stage.run_op(op4) + + # Operations have all been added to the queue + assert not stage.queue.empty() + assert stage.queue.qsize() == 4 + + # No Operations were passed down + assert stage.send_op_down.call_count == 0 + + # No Operations have been completed + assert not op1.completed + assert not op2.completed + assert not op3.completed + assert not op4.completed + + +class ConnectionLockStageBlockingOpCompletedTestConfig(ConnectionLockStageTestConfig): + @pytest.fixture(params=connection_ops) + def blocking_op(self, mocker, request): + op_cls = request.param + return op_cls(callback=mocker.MagicMock()) + + @pytest.fixture + def pending_ops(self, mocker): + op1 = ArbitraryOperation(callback=mocker.MagicMock) + op2 = ArbitraryOperation(callback=mocker.MagicMock) + op3 = ArbitraryOperation(callback=mocker.MagicMock) + pending_ops = [op1, op2, op3] + return pending_ops + + @pytest.fixture + def blocked_stage(self, mocker, init_kwargs, blocking_op, pending_ops): + stage = pipeline_stages_base.ConnectionLockStage(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=mocker.MagicMock() + ) + stage.send_op_down = mocker.MagicMock() + mocker.spy(stage, "run_op") + assert not stage.blocked + + # Set the pipeline connection state to ensure op will block + if isinstance(blocking_op, pipeline_ops_base.ConnectOperation): + stage.pipeline_root.connected = False + else: + stage.pipeline_root.connected = True + + # Block the stage by running the blocking operation + stage.run_op(blocking_op) + assert stage.blocked + + # Add pending operations + for op in pending_ops: + stage.run_op(op) + + # All pending ops should be queued + assert stage.queue.qsize() == len(pending_ops) + + # Reset the mock for ease of testing + stage.send_op_down.reset_mock() + stage.run_op.reset_mock() + return stage + + +@pytest.mark.describe( + "ConnectionLockStage - OCCURANCE: Operation blocking ConnectionLockStage is completed successfully" +) +class TestConnectionLockStageBlockingOpCompletedNoError( + ConnectionLockStageBlockingOpCompletedTestConfig +): + @pytest.mark.it("Re-runs the pending operations in FIFO order") + def test_blocking_op_completes_successfully( + self, mocker, blocked_stage, pending_ops, blocking_op + ): + stage = blocked_stage + # .run_op() has not yet been called + assert stage.run_op.call_count == 0 + + # Pending ops are queued in the stage + assert stage.queue.qsize() == len(pending_ops) + + # Complete blocking op successfully + blocking_op.complete() + + # .run_op() was called for every pending operation, in FIFO order + assert stage.run_op.call_count == len(pending_ops) + assert stage.run_op.call_args_list == [mocker.call(op) for op in pending_ops] + + # Note that this is only true because we are using arbitrary ops. Depending on what occurs during + # the .run_op() calls, this could end up having items, but that case is covered by a different test + assert stage.queue.qsize() == 0 + + @pytest.mark.it("Unblocks the ConnectionLockStage prior to re-running any pending operations") + def test_unblocks_before_rerun(self, mocker, blocked_stage, blocking_op, pending_ops): + stage = blocked_stage + mocker.spy(handle_exceptions, "handle_background_exception") + assert stage.blocked + + def run_op_override(op): + # Because the .run_op() invocation is called during operation completion, + # any exceptions, including AssertionErrors will go to the background exception handler + + # Verify that the stage is not blocked during the call to .run_op() + assert not stage.blocked + + stage.run_op = mocker.MagicMock(side_effect=run_op_override) + + blocking_op.complete() + + # Stage is still unblocked by the end of the blocking op completion + assert not stage.blocked + + # Verify that the mock .run_op() was indeed called + assert stage.run_op.call_count == len(pending_ops) + + # Verify that no assertions from the mock .run_op() turned up False + assert handle_exceptions.handle_background_exception.call_count == 0 + + @pytest.mark.it( + "Requeues subsequent operations, retaining their original order, if one of the re-run operations returns the ConnectionLockStage to a blocking state" + ) + def test_unblocked_op_changes_block_state(self, mocker, stage): + op1 = pipeline_ops_base.ConnectOperation(callback=mocker.MagicMock()) + op2 = ArbitraryOperation(callback=mocker.MagicMock()) + op3 = pipeline_ops_base.ReauthorizeConnectionOperation(callback=mocker.MagicMock()) + op4 = ArbitraryOperation(callback=mocker.MagicMock()) + op5 = ArbitraryOperation(callback=mocker.MagicMock()) + + # Block the stage on op1 + assert not stage.pipeline_root.connected + assert not stage.blocked + stage.run_op(op1) + assert stage.blocked + assert stage.queue.qsize() == 0 + + # Run the rest of the ops, which will be added to the queue + stage.run_op(op2) + stage.run_op(op3) + stage.run_op(op4) + stage.run_op(op5) + + # op1 is the only op that has been passed down so far + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op1) + assert stage.queue.qsize() == 4 + + # Complete op1 + op1.complete() + + # Manually set pipeline to be connected (this doesn't happen naturally due to the scope of this test) + stage.pipeline_root.connected = True + + # op2 and op3 have now been passed down, but no others + assert stage.send_op_down.call_count == 3 + assert stage.send_op_down.call_args_list[1] == mocker.call(op2) + assert stage.send_op_down.call_args_list[2] == mocker.call(op3) + assert stage.queue.qsize() == 2 + + # Complete op3 + op3.complete() + + # op4 and op5 are now also passed down + assert stage.send_op_down.call_count == 5 + assert stage.send_op_down.call_args_list[3] == mocker.call(op4) + assert stage.send_op_down.call_args_list[4] == mocker.call(op5) + assert stage.queue.qsize() == 0 + + +@pytest.mark.describe( + "ConnectionLockStage - OCCURANCE: Operation blocking ConnectionLockStage is completed with error" +) +class TestConnectionLockStageBlockingOpCompletedWithError( + ConnectionLockStageBlockingOpCompletedTestConfig +): + # CT-TODO: Show that completion occurs in FIFO order + @pytest.mark.it("Completes all pending operations with the error from the blocking operation") + def test_blocking_op_completes_with_error( + self, blocked_stage, pending_ops, blocking_op, arbitrary_exception + ): + stage = blocked_stage + + # Pending ops are not yet completed + for op in pending_ops: + assert not op.completed + + # Pending ops are queued in the stage + assert stage.queue.qsize() == len(pending_ops) + + # Complete blocking op with error + blocking_op.complete(error=arbitrary_exception) + + # Pending ops are now completed with error from blocking op + for op in pending_ops: + assert op.completed + assert op.error is arbitrary_exception + + # No more pending ops in stage queue + assert stage.queue.empty() + + @pytest.mark.it("Unblocks the ConnectionLockStage prior to completing any pending operations") + def test_unblocks_before_complete( + self, mocker, blocked_stage, pending_ops, blocking_op, arbitrary_exception + ): + stage = blocked_stage + mocker.spy(handle_exceptions, "handle_background_exception") + assert stage.blocked + + def complete_override(error=None): + # Because this call to .complete() is called during another op's completion, + # any exceptions, including AssertionErrors will go to the background exception handler + + # Verify that the stage is not blocked during the call to .complete() + assert not stage.blocked + + for op in pending_ops: + op.complete = mocker.MagicMock(side_effect=complete_override) + + # Complete the blocking op with error + blocking_op.complete(error=arbitrary_exception) + + # Stage is still unblocked at the end of the blocking op completion + assert not stage.blocked + + # Verify that the mock completion was called for the pending ops + for op in pending_ops: + assert op.complete.call_count == 1 + + # Verify that no assertions from the mock .complete() calls turned up False + assert handle_exceptions.handle_background_exception.call_count == 0 + + +######################################### +# COORDINATE REQUEST AND RESPONSE STAGE # +######################################### + + +@pytest.fixture +def fake_uuid(mocker): + my_uuid = "0f4f876b-f445-432e-a8de-43bbd66e4668" + uuid4_mock = mocker.patch.object(uuid, "uuid4") + uuid4_mock.return_value.__str__.return_value = my_uuid + return my_uuid + + +class CoordinateRequestAndResponseStageTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_stages_base.CoordinateRequestAndResponseStage + + @pytest.fixture + def init_kwargs(self, mocker): + return {} + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=mocker.MagicMock() + ) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + return stage + + +class CoordinateRequestAndResponseStageInstantiationTests( + CoordinateRequestAndResponseStageTestConfig +): + @pytest.mark.it("Initializes 'pending_responses' as an empty dict") + def test_pending_responses(self, init_kwargs): + stage = pipeline_stages_base.CoordinateRequestAndResponseStage(**init_kwargs) + assert stage.pending_responses == {} + + +pipeline_stage_test.add_base_pipeline_stage_tests( + test_module=this_module, + stage_class_under_test=pipeline_stages_base.CoordinateRequestAndResponseStage, + stage_test_config_class=CoordinateRequestAndResponseStageTestConfig, + extended_stage_instantiation_test_class=CoordinateRequestAndResponseStageInstantiationTests, +) + + +@pytest.mark.describe( + "CoordinateRequestAndResponseStage - .run_op() -- Called with a RequestAndResponseOperation" +) +class TestCoordinateRequestAndResponseStageRunOpWithRequestAndResponseOperation( + CoordinateRequestAndResponseStageTestConfig, StageRunOpTestBase +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_base.RequestAndResponseOperation( + request_type="some_request_type", + method="SOME_METHOD", + resource_location="some/resource/location", + request_body="some_request_body", + callback=mocker.MagicMock(), + ) + + @pytest.mark.it( + "Stores the operation in the 'pending_responses' dictionary, mapped with a generated UUID" + ) + def test_stores_op(self, mocker, stage, op, fake_uuid): + stage.run_op(op) + + assert stage.pending_responses[fake_uuid] is op + assert not op.completed + + @pytest.mark.it( + "Creates and a new RequestOperation using the generated UUID and sends it down the pipeline" + ) + def test_sends_down_new_request_op(self, mocker, stage, op, fake_uuid): + stage.run_op(op) + + assert stage.send_op_down.call_count == 1 + request_op = stage.send_op_down.call_args[0][0] + assert isinstance(request_op, pipeline_ops_base.RequestOperation) + assert request_op.method == op.method + assert request_op.resource_location == op.resource_location + assert request_op.request_body == op.request_body + assert request_op.request_type == op.request_type + assert request_op.request_id == fake_uuid + + @pytest.mark.it( + "Generates a unique UUID for each RequestAndResponseOperation/RequestOperation pair" + ) + def test_unique_uuid(self, mocker, stage, op): + op1 = op + op2 = copy.deepcopy(op) + op3 = copy.deepcopy(op) + + stage.run_op(op1) + assert stage.send_op_down.call_count == 1 + uuid1 = stage.send_op_down.call_args[0][0].request_id + stage.run_op(op2) + assert stage.send_op_down.call_count == 2 + uuid2 = stage.send_op_down.call_args[0][0].request_id + stage.run_op(op3) + assert stage.send_op_down.call_count == 3 + uuid3 = stage.send_op_down.call_args[0][0].request_id + + assert uuid1 != uuid2 != uuid3 + assert stage.pending_responses[uuid1] is op1 + assert stage.pending_responses[uuid2] is op2 + assert stage.pending_responses[uuid3] is op3 + + +@pytest.mark.describe( + "CoordinateRequestAndResponseStage - .run_op() -- Called with an arbitrary other operation" +) +class TestCoordinateRequestAndResponseStageRunOpWithArbitraryOperation( + CoordinateRequestAndResponseStageTestConfig, StageRunOpTestBase +): + @pytest.fixture + def op(self, arbitrary_op): + return arbitrary_op + + @pytest.mark.it("Sends the operation down the pipeline") + def test_sends_down(self, stage, mocker, op): + stage.run_op(op) + + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) + + +@pytest.mark.describe( + "CoordinateRequestAndResponseStage - OCCURANCE: RequestOperation tied to a stored RequestAndResponseOperation is completed" +) +class TestCoordinateRequestAndResponseStageRequestOperationCompleted( + CoordinateRequestAndResponseStageTestConfig +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_base.RequestAndResponseOperation( + request_type="some_request_type", + method="SOME_METHOD", + resource_location="some/resource/location", + request_body="some_request_body", + callback=mocker.MagicMock(), + ) + + @pytest.mark.it( + "Completes the associated RequestAndResponseOperation with the error from the RequestOperation and removes it from the 'pending_responses' dict, if the RequestOperation is completed unsuccessfully" + ) + def test_request_completed_with_error(self, mocker, stage, op, arbitrary_exception): + stage.run_op(op) + request_op = stage.send_op_down.call_args[0][0] + + assert not op.completed + assert not request_op.completed + assert stage.pending_responses[request_op.request_id] is op + + request_op.complete(error=arbitrary_exception) + + # RequestAndResponseOperation has been completed with the error from the RequestOperation + assert request_op.completed + assert op.completed + assert op.error is request_op.error is arbitrary_exception + + # RequestAndResponseOperation has been removed from the 'pending_responses' dict + with pytest.raises(KeyError): + stage.pending_responses[request_op.request_id] + + @pytest.mark.it( + "Does not complete or remove the RequestAndResponseOperation from the 'pending_responses' dict if the RequestOperation is completed successfully" + ) + def test_request_completed_successfully(self, mocker, stage, op, arbitrary_exception): + stage.run_op(op) + request_op = stage.send_op_down.call_args[0][0] + + request_op.complete() + + assert request_op.completed + assert not op.completed + assert stage.pending_responses[request_op.request_id] is op + + +@pytest.mark.describe( + "CoordinateRequestAndResponseStage - .handle_pipeline_event() -- Called with ResponseEvent" +) +class TestCoordinateRequestAndResponseStageHandlePipelineEventWithResponseEvent( + CoordinateRequestAndResponseStageTestConfig, StageHandlePipelineEventTestBase +): + @pytest.fixture + def event(self, fake_uuid): + return pipeline_events_base.ResponseEvent( + request_id=fake_uuid, status_code=200, response_body="response body" + ) + + @pytest.fixture + def pending_op(self, mocker): + return pipeline_ops_base.RequestAndResponseOperation( + request_type="some_request_type", + method="SOME_METHOD", + resource_location="some/resource/location", + request_body="some_request_body", + callback=mocker.MagicMock(), + ) + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs, fake_uuid, pending_op): + stage = cls_type(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=mocker.MagicMock() + ) + stage.send_event_up = mocker.MagicMock() + stage.send_op_down = mocker.MagicMock() + + # Run the pending op + stage.run_op(pending_op) + return stage + + @pytest.mark.it( + "Successfully completes a pending RequestAndResponseOperation that matches the 'request_id' of the ResponseEvent, and removes it from the 'pending_responses' dictionary" + ) + def test_completes_matching_request_and_response_operation( + self, mocker, stage, pending_op, event, fake_uuid + ): + assert stage.pending_responses[fake_uuid] is pending_op + assert not pending_op.completed + + # Handle the ResponseEvent + assert event.request_id == fake_uuid + stage.handle_pipeline_event(event) + + # The pending RequestAndResponseOperation is complete + assert pending_op.completed + + # The RequestAndResponseOperation has been removed from the dictionary + with pytest.raises(KeyError): + stage.pending_responses[fake_uuid] + + @pytest.mark.it( + "Sets the 'status_code' and 'response_body' attributes on the completed RequestAndResponseOperation with values from the ResponseEvent" + ) + def test_returns_values_in_attributes(self, mocker, stage, pending_op, event): + assert not pending_op.completed + assert pending_op.status_code is None + assert pending_op.response_body is None + + stage.handle_pipeline_event(event) + + assert pending_op.completed + assert pending_op.status_code == event.status_code + assert pending_op.response_body == event.response_body + + @pytest.mark.it( + "Does nothing if there is no pending RequestAndResponseOperation that matches the 'request_id' of the ResponseEvent" + ) + def test_no_matching_request_id(self, mocker, stage, pending_op, event, fake_uuid): + assert stage.pending_responses[fake_uuid] is pending_op + assert not pending_op.completed + + # Use a nonmatching UUID + event.request_id = "non-matching-uuid" + assert event.request_id != fake_uuid + stage.handle_pipeline_event(event) + + # Nothing has changed + assert stage.pending_responses[fake_uuid] is pending_op + assert not pending_op.completed + + +@pytest.mark.describe( + "CoordinateRequestAndResponseStage - .handle_pipeline_event() -- Called with arbitrary other event" +) +class TestCoordinateRequestAndResponseStageHandlePipelineEventWithArbitraryEvent( + CoordinateRequestAndResponseStageTestConfig, StageHandlePipelineEventTestBase +): + @pytest.fixture + def event(self, arbitrary_event): + return arbitrary_event + + @pytest.mark.it("Sends the event up the pipeline") + def test_sends_up(self, mocker, stage, event): + stage.handle_pipeline_event(event) + + assert stage.send_event_up.call_count == 1 + assert stage.send_event_up.call_args == mocker.call(event) + + +#################### +# OP TIMEOUT STAGE # +#################### + +ops_that_time_out = [ + pipeline_ops_mqtt.MQTTSubscribeOperation, + pipeline_ops_mqtt.MQTTUnsubscribeOperation, +] + + +class OpTimeoutStageTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_stages_base.OpTimeoutStage + + @pytest.fixture + def init_kwargs(self, mocker): + return {} + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=mocker.MagicMock() + ) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + return stage + + +class OpTimeoutStageInstantiationTests(OpTimeoutStageTestConfig): + # TODO: this will no longer be necessary once these are implemented as part of a more robust retry policy + @pytest.mark.it( + "Sets default timout intervals to 10 seconds for MQTTSubscribeOperation and MQTTUnsubscribeOperation" + ) + def test_timeout_intervals(self, init_kwargs): + stage = pipeline_stages_base.OpTimeoutStage(**init_kwargs) + assert stage.timeout_intervals[pipeline_ops_mqtt.MQTTSubscribeOperation] == 10 + assert stage.timeout_intervals[pipeline_ops_mqtt.MQTTUnsubscribeOperation] == 10 + + +pipeline_stage_test.add_base_pipeline_stage_tests( + test_module=this_module, + stage_class_under_test=pipeline_stages_base.OpTimeoutStage, + stage_test_config_class=OpTimeoutStageTestConfig, + extended_stage_instantiation_test_class=OpTimeoutStageInstantiationTests, +) + + +@pytest.mark.describe("OpTimeoutStage - .run_op() -- Called with operation eligible for timeout") +class TestOpTimeoutStageRunOpCalledWithOpThatCanTimeout( + OpTimeoutStageTestConfig, StageRunOpTestBase +): + @pytest.fixture(params=ops_that_time_out) + def op(self, mocker, request): + op_cls = request.param + op = op_cls(topic="some/topic", callback=mocker.MagicMock()) + return op + + @pytest.mark.it( + "Adds a timeout timer with the interval specified in the configuration to the operation, and starts it" + ) + def test_adds_timer(self, mocker, stage, op, mock_timer): + + stage.run_op(op) + + assert mock_timer.call_count == 1 + assert mock_timer.call_args == mocker.call(stage.timeout_intervals[type(op)], mocker.ANY) + assert op.timeout_timer is mock_timer.return_value + assert op.timeout_timer.start.call_count == 1 + assert op.timeout_timer.start.call_args == mocker.call() + + @pytest.mark.it("Sends the operation down the pipeline") + def test_sends_down(self, mocker, stage, op, mock_timer): + stage.run_op(op) + + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) + assert op.timeout_timer is mock_timer.return_value + + +@pytest.mark.describe( + "OpTimeoutStage - .run_op() -- Called with arbitrary operation that is not eligible for timeout" +) +class TestOpTimeoutStageRunOpCalledWithOpThatDoesNotTimeout( + OpTimeoutStageTestConfig, StageRunOpTestBase +): + @pytest.fixture + def op(self, arbitrary_op): + return arbitrary_op + + @pytest.mark.it("Sends the operation down the pipeline without attaching a timeout timer") + def test_sends_down(self, mocker, stage, op, mock_timer): + stage.run_op(op) + + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) + assert mock_timer.call_count == 0 + assert not hasattr(op, "timeout_timer") + + +@pytest.mark.describe( + "OpTimeoutStage - OCCURANCE: Operation with a timeout timer times out before completion" +) +class TestOpTimeoutStageOpTimesOut(OpTimeoutStageTestConfig): + @pytest.fixture(params=ops_that_time_out) + def op(self, mocker, request): + op_cls = request.param + op = op_cls(topic="some/topic", callback=mocker.MagicMock()) + return op + + @pytest.mark.it("Completes the operation unsuccessfully, with a PiplineTimeoutError") + def test_pipeline_timeout(self, mocker, stage, op, mock_timer): + # Apply the timer + stage.run_op(op) + assert not op.completed + assert mock_timer.call_count == 1 + on_timer_complete = mock_timer.call_args[0][1] + + # Call timer complete callback (indicating timer completion) + on_timer_complete() + + # Op is now completed with error + assert op.completed + assert isinstance(op.error, pipeline_exceptions.PipelineTimeoutError) + + +@pytest.mark.describe( + "OpTimeoutStage - OCCURANCE: Operation with a timeout timer completes before timeout" +) +class TestOpTimeoutStageOpCompletesBeforeTimeout(OpTimeoutStageTestConfig): + @pytest.fixture(params=ops_that_time_out) + def op(self, mocker, request): + op_cls = request.param + op = op_cls(topic="some/topic", callback=mocker.MagicMock()) + return op + + @pytest.mark.it("Cancels and clears the operation's timeout timer") + def test_complete_before_timeout(self, mocker, stage, op, mock_timer): + # Apply the timer + stage.run_op(op) + assert not op.completed + assert mock_timer.call_count == 1 + mock_timer_inst = op.timeout_timer + assert mock_timer_inst is mock_timer.return_value + assert mock_timer_inst.cancel.call_count == 0 + + # Complete the operation + op.complete() + + # Timer is now cancelled and cleared + assert mock_timer_inst.cancel.call_count == 1 + assert mock_timer_inst.cancel.call_args == mocker.call() + assert op.timeout_timer is None + + +############### +# RETRY STAGE # +############### + +# Tuples of classname + args +retryable_ops = [ + (pipeline_ops_mqtt.MQTTSubscribeOperation, {"topic": "fake_topic", "callback": fake_callback}), + ( + pipeline_ops_mqtt.MQTTUnsubscribeOperation, + {"topic": "fake_topic", "callback": fake_callback}, + ), + ( + pipeline_ops_mqtt.MQTTPublishOperation, + {"topic": "fake_topic", "payload": "fake_payload", "callback": fake_callback}, + ), +] + +retryable_exceptions = [pipeline_exceptions.PipelineTimeoutError] + + +class RetryStageTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_stages_base.RetryStage + + @pytest.fixture + def init_kwargs(self, mocker): + return {} + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=mocker.MagicMock() + ) + mocker.spy(stage, "run_op") + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + return stage + + +class RetryStageInstantiationTests(RetryStageTestConfig): + # TODO: this will no longer be necessary once these are implemented as part of a more robust retry policy + @pytest.mark.it( + "Sets default retry intervals to 20 seconds for MQTTSubscribeOperation, MQTTUnsubscribeOperation, and MQTTPublishOperation" + ) + def test_retry_intervals(self, init_kwargs): + stage = pipeline_stages_base.RetryStage(**init_kwargs) + assert stage.retry_intervals[pipeline_ops_mqtt.MQTTSubscribeOperation] == 20 + assert stage.retry_intervals[pipeline_ops_mqtt.MQTTUnsubscribeOperation] == 20 + assert stage.retry_intervals[pipeline_ops_mqtt.MQTTPublishOperation] == 20 + + @pytest.mark.it("Initializes 'ops_waiting_to_retry' as an empty list") + def test_ops_waiting_to_retry(self, init_kwargs): + stage = pipeline_stages_base.RetryStage(**init_kwargs) + assert stage.ops_waiting_to_retry == [] + + +pipeline_stage_test.add_base_pipeline_stage_tests( + test_module=this_module, + stage_class_under_test=pipeline_stages_base.RetryStage, + stage_test_config_class=RetryStageTestConfig, + extended_stage_instantiation_test_class=RetryStageInstantiationTests, +) + + +# NOTE: Although there is a branch in the implementation that distinguishes between +# retryable operations, and non-retryable operations, with retryable operations having +# a callback added, this is not captured in this test, as callback resolution is tested +# in a different unit. +@pytest.mark.describe("RetryStage - .run_op()") +class TestRetryStageRunOp(RetryStageTestConfig, StageRunOpTestBase): + ops = retryable_ops + [(ArbitraryOperation, {"callback": fake_callback})] + + @pytest.fixture(params=ops, ids=[x[0].__name__ for x in ops]) + def op(self, request, mocker): + op_cls = request.param[0] + init_kwargs = request.param[1] + return op_cls(**init_kwargs) + + @pytest.mark.it("Sends the operation down the pipeline") + def test_sends_op_down(self, mocker, stage, op): + stage.run_op(op) + + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) + + +@pytest.mark.describe( + "RetryStage - OCCURANCE: Retryable operation completes unsuccessfully with a retryable error after call to .run_op()" +) +class TestRetryStageRetryableOperationCompletedWithRetryableError(RetryStageTestConfig): + @pytest.fixture(params=retryable_ops, ids=[x[0].__name__ for x in retryable_ops]) + def op(self, request, mocker): + op_cls = request.param[0] + init_kwargs = request.param[1] + return op_cls(**init_kwargs) + + @pytest.fixture(params=retryable_exceptions) + def error(self, request): + return request.param() + + @pytest.mark.it("Halts operation completion") + def test_halt(self, mocker, stage, op, error, mock_timer): + stage.run_op(op) + op.complete(error=error) + + assert not op.completed + + @pytest.mark.it( + "Adds a retry timer to the operation with the interval specified for the operation by the configuration, and starts it" + ) + def test_timer(self, mocker, stage, op, error, mock_timer): + stage.run_op(op) + op.complete(error=error) + + assert mock_timer.call_count == 1 + assert mock_timer.call_args == mocker.call(stage.retry_intervals[type(op)], mocker.ANY) + assert op.retry_timer is mock_timer.return_value + assert op.retry_timer.start.call_count == 1 + assert op.retry_timer.start.call_args == mocker.call() + + @pytest.mark.it( + "Adds the operation to the list of 'ops_waiting_to_retry' only for the duration of the timer" + ) + def test_adds_to_waiting_list_during_timer(self, mocker, stage, op, error, mock_timer): + stage.run_op(op) + + # The op is not listed as waiting for retry before completion + assert op not in stage.ops_waiting_to_retry + + # Completing the op starts the timer + op.complete(error=error) + assert mock_timer.call_count == 1 + timer_callback = mock_timer.call_args[0][1] + assert mock_timer.return_value.start.call_count == 1 + + # Once completed and the timer has been started, the op IS listed as waiting for retry + assert op in stage.ops_waiting_to_retry + + # Simulate timer completion + timer_callback() + + # Once the timer is completed, the op is no longer listed as waiting for retry + assert op not in stage.ops_waiting_to_retry + + @pytest.mark.it("Re-runs the operation after the retry timer expires") + def test_reruns(self, mocker, stage, op, error, mock_timer): + stage.run_op(op) + op.complete(error=error) + + assert stage.run_op.call_count == 1 + assert mock_timer.call_count == 1 + timer_callback = mock_timer.call_args[0][1] + + # Simulate timer completion + timer_callback() + + # run_op was called again + assert stage.run_op.call_count == 2 + + @pytest.mark.it("Cancels and clears the retry timer after the retry timer expires") + def test_clears_retry_timer(self, mocker, stage, op, error, mock_timer): + stage.run_op(op) + op.complete(error=error) + timer_callback = mock_timer.call_args[0][1] + + assert mock_timer.cancel.call_count == 0 + assert op.retry_timer is mock_timer.return_value + + # Simulate timer completion + timer_callback() + + assert mock_timer.return_value.cancel.call_count == 1 + assert mock_timer.return_value.cancel.call_args == mocker.call() + assert op.retry_timer is None + + @pytest.mark.it( + "Adds a new retry timer to the re-run operation, if it completes unsuccessfully again" + ) + def test_rerun_op_unsuccessful_again(self, mocker, stage, op, error, mock_timer): + stage.run_op(op) + assert stage.run_op.call_count == 1 + + # Complete with failure the first time + op.complete(error=error) + + assert mock_timer.call_count == 1 + assert op.retry_timer is mock_timer.return_value + timer_callback1 = mock_timer.call_args[0][1] + + # Trigger retry + timer_callback1() + + assert stage.run_op.call_count == 2 + assert stage.run_op.call_args == mocker.call(op) + assert op.retry_timer is None + + # Complete with failure the second time + op.complete(error=error) + + assert mock_timer.call_count == 2 + assert op.retry_timer is mock_timer.return_value + timer_callback2 = mock_timer.call_args[0][1] + + # Trigger retry again + timer_callback2() + + assert stage.run_op.call_count == 3 + assert stage.run_op.call_args == mocker.call(op) + assert op.retry_timer is None + + @pytest.mark.it("Supports multiple simultaneous operations retrying") + def test_multiple_retries(self, mocker, stage, mock_timer): + op1 = pipeline_ops_mqtt.MQTTSubscribeOperation( + topic="fake_topic", callback=mocker.MagicMock() + ) + op2 = pipeline_ops_mqtt.MQTTPublishOperation( + topic="fake_topic", payload="fake_payload", callback=mocker.MagicMock() + ) + op3 = pipeline_ops_mqtt.MQTTUnsubscribeOperation( + topic="fake_topic", callback=mocker.MagicMock() + ) + + stage.run_op(op1) + stage.run_op(op2) + stage.run_op(op3) + assert stage.run_op.call_count == 3 + + assert not op1.completed + assert not op2.completed + assert not op3.completed + + op1.complete(error=pipeline_exceptions.PipelineTimeoutError()) + op2.complete(error=pipeline_exceptions.PipelineTimeoutError()) + op3.complete(error=pipeline_exceptions.PipelineTimeoutError()) + + # Ops halted + assert not op1.completed + assert not op2.completed + assert not op3.completed + + # Timers set + assert mock_timer.call_count == 3 + assert op1.retry_timer is mock_timer.return_value + assert op2.retry_timer is mock_timer.return_value + assert op3.retry_timer is mock_timer.return_value + assert mock_timer.return_value.start.call_count == 3 + + # Operations awaiting retry + assert op1 in stage.ops_waiting_to_retry + assert op2 in stage.ops_waiting_to_retry + assert op3 in stage.ops_waiting_to_retry + + timer1_complete = mock_timer.call_args_list[0][0][1] + timer2_complete = mock_timer.call_args_list[1][0][1] + timer3_complete = mock_timer.call_args_list[2][0][1] + + # Trigger op1's timer to complete + timer1_complete() + + # Only op1 was re-run, and had it's timer removed + assert mock_timer.return_value.cancel.call_count == 1 + assert op1.retry_timer is None + assert op1 not in stage.ops_waiting_to_retry + assert op2.retry_timer is mock_timer.return_value + assert op2 in stage.ops_waiting_to_retry + assert op3.retry_timer is mock_timer.return_value + assert op3 in stage.ops_waiting_to_retry + assert stage.run_op.call_count == 4 + assert stage.run_op.call_args == mocker.call(op1) + + # Trigger op2's timer to complete + timer2_complete() + + # Only op2 was re-run and had it's timer removed + assert mock_timer.return_value.cancel.call_count == 2 + assert op2.retry_timer is None + assert op2 not in stage.ops_waiting_to_retry + assert op3.retry_timer is mock_timer.return_value + assert op3 in stage.ops_waiting_to_retry + assert stage.run_op.call_count == 5 + assert stage.run_op.call_args == mocker.call(op2) + + # Trigger op3's timer to complete + timer3_complete() + + # op3 has now also been re-run and had it's timer removed + assert op3.retry_timer is None + assert op3 not in stage.ops_waiting_to_retry + assert stage.run_op.call_count == 6 + assert stage.run_op.call_args == mocker.call(op3) + + +@pytest.mark.describe( + "RetryStage - OCCURANCE: Retryable operation completes unsucessfully with a non-retryable error after call to .run_op()" +) +class TestRetryStageRetryableOperationCompletedWithNonRetryableError(RetryStageTestConfig): + @pytest.fixture(params=retryable_ops, ids=[x[0].__name__ for x in retryable_ops]) + def op(self, request, mocker): + op_cls = request.param[0] + init_kwargs = request.param[1] + return op_cls(**init_kwargs) + + @pytest.fixture + def error(self, arbitrary_exception): + return arbitrary_exception + + @pytest.mark.it("Completes normally without retry") + def test_no_retry(self, mocker, stage, op, error, mock_timer): + stage.run_op(op) + op.complete(error=error) + + assert op.completed + assert op not in stage.ops_waiting_to_retry + assert mock_timer.call_count == 0 + + @pytest.mark.it("Cancels and clears the operation's retry timer, if one exists") + def test_cancels_existing_timer(self, mocker, stage, op, error, mock_timer): + # NOTE: This shouldn't happen naturally. We have to artificially create this circumstance + stage.run_op(op) + + # Artificially add a timer. Note that this is already mocked due to the 'mock_timer' fixture + op.retry_timer = threading.Timer(20, fake_callback) + assert op.retry_timer is mock_timer.return_value + + op.complete(error=error) + + assert op.completed + assert mock_timer.return_value.cancel.call_count == 1 + assert op.retry_timer is None + + +@pytest.mark.describe( + "RetryStage - OCCURANCE: Retryable operation completes successfully after call to .run_op()" +) +class TestRetryStageRetryableOperationCompletedSuccessfully(RetryStageTestConfig): + @pytest.fixture(params=retryable_ops, ids=[x[0].__name__ for x in retryable_ops]) + def op(self, request, mocker): + op_cls = request.param[0] + init_kwargs = request.param[1] + return op_cls(**init_kwargs) + + @pytest.mark.it("Completes normally without retry") + def test_no_retry(self, mocker, stage, op, mock_timer): + stage.run_op(op) + op.complete() + + assert op.completed + assert op not in stage.ops_waiting_to_retry + assert mock_timer.call_count == 0 + + # NOTE: this isn't doing anything because arb ops don't trigger callback + @pytest.mark.it("Cancels and clears the operation's retry timer, if one exists") + def test_cancels_existing_timer(self, mocker, stage, op, mock_timer): + # NOTE: This shouldn't happen naturally. We have to artificially create this circumstance + stage.run_op(op) + + # Artificially add a timer. Note that this is already mocked due to the 'mock_timer' fixture + op.retry_timer = threading.Timer(20, fake_callback) + assert op.retry_timer is mock_timer.return_value + + op.complete() + + assert op.completed + assert mock_timer.return_value.cancel.call_count == 1 + assert op.retry_timer is None + + +@pytest.mark.describe( + "RetryStage - OCCURANCE: Non-retryable operation completes after call to .run_op()" +) +class TestRetryStageNonretryableOperationCompleted(RetryStageTestConfig): + @pytest.fixture + def op(self, arbitrary_op): + return arbitrary_op + + @pytest.mark.it("Completes normally without retry, if completed successfully") + def test_successful_completion(self, mocker, stage, op, mock_timer): + stage.run_op(op) + op.complete() + + assert op.completed + assert op not in stage.ops_waiting_to_retry + assert mock_timer.call_count == 0 + + @pytest.mark.it( + "Completes normally without retry, if completed unsucessfully with a non-retryable exception" + ) + def test_unsucessful_non_retryable_err( + self, mocker, stage, op, arbitrary_exception, mock_timer + ): + stage.run_op(op) + op.complete(error=arbitrary_exception) + + assert op.completed + assert op not in stage.ops_waiting_to_retry + assert mock_timer.call_count == 0 + + @pytest.mark.it( + "Completes normally without retry, if completed unsucessfully with a retryable exception" + ) + @pytest.mark.parametrize("exception", retryable_exceptions) + def test_unsucessful_retryable_err(self, mocker, stage, op, exception, mock_timer): + stage.run_op(op) + op.complete(error=exception) + + assert op.completed + assert op not in stage.ops_waiting_to_retry + assert mock_timer.call_count == 0 + + +################### +# RECONNECT STAGE # +################### + + +class ReconnectStageTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_stages_base.ReconnectStage + + @pytest.fixture + def init_kwargs(self, mocker): + return {} + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=mocker.MagicMock() + ) + mocker.spy(stage, "run_op") + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + return stage + + +class ReconnectStageInstantiationTests(ReconnectStageTestConfig): + @pytest.mark.it("Initializes the 'reconnect_timer' attribute as None") + def test_reconnect_timer(self, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + assert stage.reconnect_timer is None + + @pytest.mark.it("Initializes the 'state' attribute as 'NEVER_CONNECTED'") + def test_state(self, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + assert stage.state == pipeline_stages_base.ReconnectState.NEVER_CONNECTED + + @pytest.mark.it("Initializes the 'waiting_connect_ops' attribute as []") + def test_waiting_connect_ops(self, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + assert stage.waiting_connect_ops == [] + + # TODO: this will not be necessary once retry policy is implemented more fully + @pytest.mark.it("Initializes the 'reconnect_delay' attribute/setting to 10 seconds") + def test_reconnect_delay(self, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + assert stage.reconnect_delay == 10 + + +pipeline_stage_test.add_base_pipeline_stage_tests( + test_module=this_module, + stage_class_under_test=pipeline_stages_base.ReconnectStage, + stage_test_config_class=ReconnectStageTestConfig, + extended_stage_instantiation_test_class=ReconnectStageInstantiationTests, +) + + +@pytest.mark.describe("ReconnectStage - .run_op() -- Called with ConnectOperation") +class TestReconnectStageRunOpWithConnectOperation(ReconnectStageTestConfig, StageRunOpTestBase): + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=mocker.MagicMock() + ) + mocker.spy(stage, "run_op") + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() return stage @pytest.fixture - def connection_op(self, mocker, params): - return params["op_class"](callback=mocker.MagicMock()) + def fake_waiting_connect_ops(self, mocker): + op1 = ArbitraryOperation(callback=mocker.MagicMock()) + op1.original_callback = op1.callback_stack[0] + op2 = ArbitraryOperation(callback=mocker.MagicMock()) + op2.original_callback = op2.callback_stack[0] + return list([op1, op2]) @pytest.fixture - def fake_op(self, mocker): - return FakeOperation(callback=mocker.MagicMock()) - - @pytest.fixture - def fake_ops(self, mocker): - return [ - FakeOperation(callback=mocker.MagicMock()), - FakeOperation(callback=mocker.MagicMock()), - FakeOperation(callback=mocker.MagicMock()), - ] - - @pytest.mark.it( - "Immediately completes a ConnectOperation if the transport is already connected" - ) - def test_connect_while_connected(self, stage, mocker): - op = pipeline_ops_base.ConnectOperation(callback=mocker.MagicMock()) - stage.pipeline_root.connected = True - stage.run_op(op) - assert_callback_succeeded(op=op) - - @pytest.mark.it( - "Immediately completes a DisconnectOperation if the transport is already disconnected" - ) - def test_disconnect_while_disconnected(self, stage, mocker): - op = pipeline_ops_base.DisconnectOperation(callback=mocker.MagicMock()) - stage.pipeline_root.connected = False - stage.run_op(op) - assert_callback_succeeded(op=op) - - @pytest.mark.it( - "Immediately passes the operation down if an operation is not alrady being blocking the stage" - ) - def test_passes_op_when_not_blocked(self, stage, mocker, fake_op): - stage.run_op(fake_op) - assert stage.next.run_op.call_count == 1 - assert stage.next.run_op.call_args[0][0] == fake_op + def op(self, mocker): + return pipeline_ops_base.ConnectOperation(callback=mocker.MagicMock()) @pytest.mark.parametrize( - "params", connection_ops, ids=[x["op_class"].__name__ for x in connection_ops] - ) - @pytest.mark.it( - "Does not immediately pass the operation down if a different operation is currently blcking the stage" - ) - def test_does_not_pass_op_if_blocked(self, params, stage, connection_op, fake_op): - stage.pipeline_root.connected = params["connected_flag_required_to_run"] - stage.run_op(connection_op) - stage.run_op(fake_op) - - assert stage.next.run_op.call_count == 1 - assert stage.next.run_op.call_args[0][0] == connection_op - - @pytest.mark.parametrize( - "params", connection_ops, ids=[x["op_class"].__name__ for x in connection_ops] - ) - @pytest.mark.it( - "Waits for the operation that is currently blocking the stage to complete before passing the op down" - ) - def test_waits_for_serialized_op_to_complete_before_passing_blocked_op( - self, params, stage, connection_op, fake_op - ): - stage.pipeline_root.connected = params["connected_flag_required_to_run"] - stage.run_op(connection_op) - stage.run_op(fake_op) - operation_flow.complete_op(stage=stage.next, op=connection_op) - - assert stage.next.run_op.call_count == 2 - assert stage.next.run_op.call_args[0][0] == fake_op - - @pytest.mark.parametrize( - "params", connection_ops, ids=[x["op_class"].__name__ for x in connection_ops] - ) - @pytest.mark.it("Fails the operation if the operation that previously blocked the stage fails") - def test_fails_blocked_op_if_serialized_op_fails( - self, params, stage, connection_op, fake_op, fake_exception - ): - stage.pipeline_root.connected = params["connected_flag_required_to_run"] - stage.run_op(connection_op) - stage.run_op(fake_op) - connection_op.error = fake_exception - operation_flow.complete_op(stage=stage.next, op=connection_op) - assert_callback_failed(op=fake_op, error=fake_exception) - - @pytest.mark.parametrize( - "params", connection_ops, ids=[x["op_class"].__name__ for x in connection_ops] - ) - @pytest.mark.it( - "Can pend multiple operations while waiting for an operation that is currently blocking the stage" - ) - def test_blocks_multiple_ops(self, params, stage, connection_op, fake_ops): - stage.pipeline_root.connected = params["connected_flag_required_to_run"] - stage.run_op(connection_op) - for op in fake_ops: - stage.run_op(op) - assert stage.next.run_op.call_count == 1 - - @pytest.mark.parametrize( - "params", connection_ops, ids=[x["op_class"].__name__ for x in connection_ops] - ) - @pytest.mark.it( - "Passes down all pending operations after the operation that previously blocked the stage completes successfully" - ) - def test_unblocks_multiple_ops(self, params, stage, connection_op, fake_ops): - stage.pipeline_root.connected = params["connected_flag_required_to_run"] - stage.run_op(connection_op) - for op in fake_ops: - stage.run_op(op) - - operation_flow.complete_op(stage=stage.next, op=connection_op) - - assert stage.next.run_op.call_count == 1 + len(fake_ops) - - # zip our ops and our calls together and make sure they match - run_ops = zip(fake_ops, stage.next.run_op.call_args_list[1:]) - for run_op in run_ops: - op = run_op[0] - call_args = run_op[1] - assert op == call_args[0][0] - - @pytest.mark.parametrize( - "params", connection_ops, ids=[x["op_class"].__name__ for x in connection_ops] - ) - @pytest.mark.it( - "Fails all pending operations after the operation that previously blocked the stage fails" - ) - def test_fails_multiple_ops(self, params, stage, connection_op, fake_ops, fake_exception): - stage.pipeline_root.connected = params["connected_flag_required_to_run"] - stage.run_op(connection_op) - for op in fake_ops: - stage.run_op(op) - - connection_op.error = fake_exception - operation_flow.complete_op(stage=stage.next, op=connection_op) - - for op in fake_ops: - assert_callback_failed(op=op, error=fake_exception) - - @pytest.mark.it( - "Does not immediately pass down operations in the queue if an operation in the queue causes the stage to re-block" - ) - def test_re_blocks_ops_from_queue(self, stage, mocker): - first_connect = pipeline_ops_base.ConnectOperation(callback=mocker.MagicMock()) - first_fake_op = FakeOperation(callback=mocker.MagicMock()) - second_connect = pipeline_ops_base.ReconnectOperation(callback=mocker.MagicMock()) - second_fake_op = FakeOperation(callback=mocker.MagicMock()) - - stage.run_op(first_connect) - stage.run_op(first_fake_op) - stage.run_op(second_connect) - stage.run_op(second_fake_op) - - # at this point, ops are pended waiting for the first connect to complete. Verify this and complete the connect. - assert stage.next.run_op.call_count == 1 - assert stage.next.run_op.call_args[0][0] == first_connect - operation_flow.complete_op(stage=stage.next, op=first_connect) - - # The connect is complete. This passes down first_fake_op and second_connect and second_fake_op gets pended waiting i - # for second_connect to complete. - # Note: this isn't ideal. In a perfect world, second_connect wouldn't start until first_fake_op is complete, but we - # dont have this logic in place yet. - assert stage.next.run_op.call_count == 3 - assert stage.next.run_op.call_args_list[1][0][0] == first_fake_op - assert stage.next.run_op.call_args_list[2][0][0] == second_connect - - # now, complete second_connect to give second_fake_op a chance to get passed down - operation_flow.complete_op(stage=stage.next, op=second_connect) - assert stage.next.run_op.call_count == 4 - assert stage.next.run_op.call_args_list[3][0][0] == second_fake_op - - @pytest.mark.parametrize( - "params", + "state", [ - pytest.param( - { - "pre_connected_flag": True, - "first_connection_op": pipeline_ops_base.DisconnectOperation, - "mid_connect_flag": False, - "second_connection_op": pipeline_ops_base.DisconnectOperation, - }, - id="Disconnect followed by Disconnect", - ), - pytest.param( - { - "pre_connected_flag": False, - "first_connection_op": pipeline_ops_base.ConnectOperation, - "mid_connect_flag": True, - "second_connection_op": pipeline_ops_base.ConnectOperation, - }, - id="Connect followed by Connect", - ), - pytest.param( - { - "pre_connected_flag": True, - "first_connection_op": pipeline_ops_base.ReconnectOperation, - "mid_connect_flag": True, - "second_connection_op": pipeline_ops_base.ConnectOperation, - }, - id="Reconnect followed by Connect", - ), + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT, + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + ], + ) + @pytest.mark.it("Does not complete the operation") + def test_does_not_immediately_complete(self, stage, op, state): + stage.state = state + callback = op.callback_stack[0] + stage.run_op(op) + assert callback.call_count == 0 + + @pytest.mark.parametrize( + "state", + [ + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT, + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + ], + ) + @pytest.mark.it("adds the op to the waiting_connect_ops list") + def test_adds_to_waiting_connect_ops(self, stage, op, state, fake_waiting_connect_ops): + stage.state = state + stage.waiting_connect_ops = fake_waiting_connect_ops + waiting_connect_ops_copy = list(fake_waiting_connect_ops) + stage.run_op(op) + waiting_connect_ops_copy.append(op) + assert stage.waiting_connect_ops == waiting_connect_ops_copy + + @pytest.mark.parametrize( + "state", + [ + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT, + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + ], + ) + @pytest.mark.it("does not complete any waiting ops") + def test_does_not_complete_waiting_connect_ops( + self, stage, op, state, fake_waiting_connect_ops + ): + stage.state = state + stage.waiting_connect_ops = fake_waiting_connect_ops + waiting_connect_ops_copy = list(fake_waiting_connect_ops) + stage.run_op(op) + for op in waiting_connect_ops_copy: + assert op.original_callback.call_count == 0 + + @pytest.mark.parametrize( + "state", + [ + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + ], + ) + @pytest.mark.it("Sends a new connect op down") + def test_sends_new_op_down(self, stage, op, state): + stage.state = state + stage.run_op(op) + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_base.ConnectOperation) + assert new_op != op + + @pytest.mark.parametrize("state", [pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT]) + @pytest.mark.it("Does not send a new connect op down") + def test_does_not_send_new_op_down(self, stage, op, state): + stage.state = state + stage.run_op(op) + assert stage.send_op_down.call_count == 0 + + @pytest.mark.parametrize( + "state", + [ + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT, + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + ], + ) + @pytest.mark.it("Does not change the state") + def test_does_not_change_state(self, stage, op, state): + stage.state = state + stage.run_op(op) + assert stage.state == state + + @pytest.mark.parametrize( + "state", + [ + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT, + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + ], + ) + @pytest.mark.it("Does not cancel, clear or set a reconnect timer") + def test_timer_untouched(self, mocker, stage, op, mock_timer, state): + stage.state = state + original_timer = stage.reconnect_timer + stage.run_op(op) + + assert stage.reconnect_timer is original_timer + if stage.reconnect_timer: + assert stage.reconnect_timer.cancel.call_count == 0 + assert mock_timer.call_count == 0 + + +@pytest.mark.describe("ReconnectStage - .run_op() -- Called with DisconnectOperation") +class TestReconnectStageRunOpWithDisconnectOperation(ReconnectStageTestConfig, StageRunOpTestBase): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_base.DisconnectOperation(callback=mocker.MagicMock()) + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=mocker.MagicMock() + ) + mocker.spy(stage, "run_op") + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + return stage + + @pytest.fixture + def fake_waiting_connect_ops(self, mocker): + op1 = ArbitraryOperation(callback=mocker.MagicMock()) + op1.original_callback = op1.callback_stack[0] + op2 = ArbitraryOperation(callback=mocker.MagicMock()) + op2.original_callback = op2.callback_stack[0] + return list([op1, op2]) + + @pytest.mark.parametrize("state", [pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT]) + @pytest.mark.it("Immediately completes the op") + def test_completes_op(self, stage, op, state, mocker): + stage.state = state + callback = op.callback_stack[0] + stage.run_op(op) + assert callback.call_count == 1 + assert callback.call_args == mocker.call(op=op, error=None) + + @pytest.mark.parametrize( + "state", + [ + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + ], + ) + @pytest.mark.it("Does not immediately complete the op") + def test_does_not_complete_op(self, stage, op, state): + stage.state = state + callback = op.callback_stack[0] + stage.run_op(op) + assert callback.call_count == 0 + + @pytest.mark.parametrize( + "state", + [ + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + ], + ) + @pytest.mark.it("Sends the op down") + def test_sends_op_down(self, stage, op, state, mocker): + stage.state = state + stage.run_op(op) + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) + + @pytest.mark.parametrize("state", [pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT]) + @pytest.mark.it("Does not send the op down") + def test_does_not_send_op_down(self, stage, op, state): + stage.state = state + stage.run_op(op) + assert stage.send_op_down.call_count == 0 + + @pytest.mark.parametrize("state", [pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT]) + @pytest.mark.it("Clears the reconnect timer") + def test_clears_reconnect_timer(self, stage, op, state, mocker): + stage.state = state + reconnect_timer = mocker.MagicMock() + stage.reconnect_timer = reconnect_timer + stage.run_op(op) + assert stage.reconnect_timer is None + assert reconnect_timer.cancel.call_count == 1 + assert reconnect_timer.cancel.call_args == mocker.call() + + @pytest.mark.parametrize( + "state", + [ + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + ], + ) + @pytest.mark.it("Does not cancel, clear or set a reconnect timer") + def test_timer_untouched(self, mocker, stage, op, mock_timer, state): + stage.state = state + original_timer = stage.reconnect_timer + stage.run_op(op) + + assert stage.reconnect_timer is original_timer + if stage.reconnect_timer: + assert stage.reconnect_timer.cancel.call_count == 0 + assert mock_timer.call_count == 0 + + @pytest.mark.parametrize("state", [pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT]) + @pytest.mark.it("Changes the state to CONNECTED_OR_DISCONNECTED") + def test_changes_state(self, stage, op, state): + stage.state = state + stage.run_op(op) + assert stage.state == pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED + + @pytest.mark.parametrize( + "state", + [ + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + ], + ) + @pytest.mark.it("Does not change the state") + def test_does_not_change_state(self, stage, op, state): + stage.state = state + stage.run_op(op) + assert stage.state == state + + @pytest.mark.parametrize("state", [pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT]) + @pytest.mark.it("Cancels all ops in the waiting list") + def test_cancels_waiting_connect_ops(self, stage, op, state, fake_waiting_connect_ops): + stage.state = state + stage.waiting_connect_ops = fake_waiting_connect_ops + waiting_connect_ops_copy = list(fake_waiting_connect_ops) + stage.run_op(op) + assert stage.waiting_connect_ops == [] + for op in waiting_connect_ops_copy: + assert op.original_callback.call_count == 1 + error = op.original_callback.call_args[1]["error"] + assert isinstance(error, pipeline_exceptions.OperationCancelled) + + @pytest.mark.parametrize( + "state", + [ + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + ], + ) + @pytest.mark.it("Does not add, remove, or complete any ops in the waiting ops list") + def test_waiting_connect_ops_list_untouched(self, stage, op, state, fake_waiting_connect_ops): + stage.state = state + stage.waiting_connect_ops = fake_waiting_connect_ops + waiting_connect_ops_copy = list(fake_waiting_connect_ops) + stage.run_op(op) + assert stage.waiting_connect_ops == waiting_connect_ops_copy + for op in stage.waiting_connect_ops: + assert op.original_callback.call_count == 0 + + +@pytest.mark.describe("ReconnectStage - .run_op() -- Called with arbitrary other operation") +class TestReconnectStageRunOpWithArbitraryOperation(ReconnectStageTestConfig, StageRunOpTestBase): + @pytest.fixture + def op(self, arbitrary_op): + return arbitrary_op + + @pytest.fixture( + params=[ + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT, + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + ] + ) + def state(self, request): + return request.param + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs, state): + stage = cls_type(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=mocker.MagicMock() + ) + mocker.spy(stage, "run_op") + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + stage.state = state + return stage + + @pytest.fixture + def fake_waiting_connect_ops(self, mocker): + op1 = ArbitraryOperation(callback=mocker.MagicMock()) + op1.original_callback = op1.callback_stack[0] + op2 = ArbitraryOperation(callback=mocker.MagicMock()) + op2.original_callback = op2.callback_stack[0] + return list([op1, op2]) + + @pytest.mark.it("Does not change the state") + def test_state_unchanged(self, stage, op): + original_state = stage.state + stage.run_op(op) + assert stage.state is original_state + + @pytest.mark.it("Sends the operation down the pipeline") + def test_sends_op_down(self, mocker, stage, op): + stage.run_op(op) + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) + + @pytest.mark.it("Does not cancel, clear or set a reconnect timer") + def test_timer_untouched(self, mocker, stage, op, mock_timer): + original_timer = stage.reconnect_timer + stage.run_op(op) + + assert stage.reconnect_timer is original_timer + if stage.reconnect_timer: + assert stage.reconnect_timer.cancel.call_count == 0 + assert mock_timer.call_count == 0 + + @pytest.mark.it("Does not add, remove, or complete any ops in the waiting ops list") + def test_waiting_connect_ops_list_untouched(self, stage, op, state, fake_waiting_connect_ops): + stage.state = state + stage.waiting_connect_ops = fake_waiting_connect_ops + waiting_connect_ops_copy = list(fake_waiting_connect_ops) + stage.run_op(op) + assert stage.waiting_connect_ops == waiting_connect_ops_copy + for op in stage.waiting_connect_ops: + assert op.original_callback.call_count == 0 + + +@pytest.mark.describe("ReconnectStage - .handle_pipeline_event() -- Called with a ConnectedEvent") +class TestReconnectStageHandlePipelineEventWithConnectedEvent( + ReconnectStageTestConfig, StageHandlePipelineEventTestBase +): + @pytest.fixture( + params=[ + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT, + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + ] + ) + def state(self, request): + return request.param + + @pytest.fixture(params=[True, False], ids=["Connected", "Disconnected"]) + def connected(self, request): + return request.param + + @pytest.fixture + def fake_waiting_connect_ops(self, mocker): + op1 = ArbitraryOperation(callback=mocker.MagicMock()) + op1.original_callback = op1.callback_stack[0] + op2 = ArbitraryOperation(callback=mocker.MagicMock()) + op2.original_callback = op2.callback_stack[0] + return list([op1, op2]) + + @pytest.fixture( + params=[True, False], ids=["Existing Reconnect Timer", "No Existing Reconnect Timer"] + ) + def reconnect_timer(self, request, mocker): + if request.param: + return mocker.MagicMock() + else: + return None + + @pytest.fixture() + def stage(self, mocker, cls_type, init_kwargs, connected, reconnect_timer): + stage = cls_type(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=mocker.MagicMock() + ) + mocker.spy(stage, "run_op") + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + stage.pipeline_root.connected = connected + stage.reconnect_timer = reconnect_timer + return stage + + @pytest.fixture + def event(self): + return pipeline_events_base.ConnectedEvent() + + @pytest.mark.it("Sends the event up the pipeline") + def test_sends_event_up(self, mocker, stage, event, state): + stage.state = state + stage.handle_pipeline_event(event) + assert stage.send_event_up.call_count == 1 + assert stage.send_event_up.call_args == mocker.call(event) + + @pytest.mark.it("Does not add, remove, or complete any ops in the waiting ops list") + def test_waiting_connect_ops_list_untouched( + self, stage, event, state, fake_waiting_connect_ops + ): + stage.state = state + stage.waiting_connect_ops = fake_waiting_connect_ops + waiting_connect_ops_copy = list(fake_waiting_connect_ops) + stage.handle_pipeline_event(event) + assert stage.waiting_connect_ops == waiting_connect_ops_copy + for op in stage.waiting_connect_ops: + assert op.original_callback.call_count == 0 + + @pytest.mark.it("Does not cancel, clear or set a reconnect timer") + def test_timer_untouched(self, mocker, stage, event, mock_timer): + original_timer = stage.reconnect_timer + stage.handle_pipeline_event(event) + + assert stage.reconnect_timer is original_timer + if stage.reconnect_timer: + assert stage.reconnect_timer.cancel.call_count == 0 + assert mock_timer.call_count == 0 + + +@pytest.mark.describe( + "ReconnectStage - .handle_pipeline_event() -- Called with a DisconnectedEvent" +) +class TestReconnectStageHandlePipelineEventWithDisconnectedEvent( + ReconnectStageTestConfig, StageHandlePipelineEventTestBase +): + @pytest.fixture( + params=[True, False], ids=["Existing Reconnect Timer", "No Existing Reconnect Timer"] + ) + def reconnect_timer(self, request, mocker): + if request.param: + return mocker.MagicMock() + else: + return None + + @pytest.fixture( + params=[ + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT, + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + ] + ) + def state(self, request): + return request.param + + @pytest.fixture() + def stage(self, mocker, cls_type, init_kwargs, state, reconnect_timer, mock_timer): + # mock_timer fixture is used here so none of these tests create an actual timer. + stage = cls_type(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=mocker.MagicMock() + ) + stage.state = state + stage.reconnect_timer = reconnect_timer + mocker.spy(stage, "run_op") + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + return stage + + @pytest.fixture + def event(self): + return pipeline_events_base.DisconnectedEvent() + + @pytest.mark.it("If previously connected, changes the state to WAITING_TO_RECONNECT") + def test_changes_state(self, stage, event): + stage.pipeline_root.connected = True + stage.handle_pipeline_event(event) + assert stage.state == pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT + + @pytest.mark.it("If not previously connectred, does not change the state") + def test_does_not_change_state(self, stage, event): + stage.pipeline_root.connected = False + original_state = stage.state + stage.handle_pipeline_event(event) + assert stage.state == original_state + + @pytest.mark.it("If previously connected, clears the previous reconnect timer if there was one") + def test_clears_reconnect_timer(self, stage, event): + old_timer = stage.reconnect_timer + stage.pipeline_root.connected = True + stage.handle_pipeline_event(event) + if old_timer: + assert old_timer.cancel.call_count == 1 + assert stage.reconnect_timer != old_timer + + @pytest.mark.it( + "If not previously connected, does not clears the previous reconnect timer if there was one" + ) + def test_does_not_clear_reconnect_timer(self, stage, event): + old_timer = stage.reconnect_timer + stage.pipeline_root.connected = False + stage.handle_pipeline_event(event) + if old_timer: + assert old_timer.cancel.call_count == 0 + assert stage.reconnect_timer == old_timer + + @pytest.mark.it("If previously connected, sets a new reconnect timer") + def test_sets_new_reconnect_timer(self, stage, event, mock_timer): + stage.pipeline_root.connected = True + stage.handle_pipeline_event(event) + assert mock_timer.call_count == 1 + assert stage.reconnect_timer == mock_timer.return_value + assert stage.reconnect_timer.start.call_count == 1 + + @pytest.mark.it("If not previously connected, does not set a new reconnect timer") + def test_does_not_set_new_reconnect_timer(self, stage, event, mock_timer): + old_reconnect_timer = stage.reconnect_timer + stage.pipeline_root.connected = False + stage.handle_pipeline_event(event) + assert mock_timer.call_count == 0 + assert stage.reconnect_timer == old_reconnect_timer + if stage.reconnect_timer: + assert stage.reconnect_timer.start.call_count == 0 + + @pytest.mark.parametrize( + "previously_connected", + [ + pytest.param(True, id="Previously conencted"), + pytest.param(False, id="Not previously connected"), + ], + ) + @pytest.mark.it("Sends the event up") + def test_sends_event_up(self, stage, event, previously_connected, mocker): + stage.pipeline_root.connected = previously_connected + stage.handle_pipeline_event(event) + assert stage.send_event_up.call_count == 1 + assert stage.send_event_up.call_args == mocker.call(event) + + +@pytest.mark.describe( + "ReconnectStage - .handle_pipeline_event() -- Called with some other arbitrary event" +) +class TestReconnectStageHandlePipelineEventWithArbitraryEvent( + ReconnectStageTestConfig, StageHandlePipelineEventTestBase +): + @pytest.fixture( + params=[ + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT, + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + ] + ) + def state(self, request): + return request.param + + @pytest.fixture( + params=[True, False], ids=["Existing Reconnect Timer", "No Existing Reconnect Timer"] + ) + def reconnect_timer(self, request, mocker): + if request.param: + return mocker.MagicMock() + else: + return None + + @pytest.fixture() + def stage(self, mocker, cls_type, init_kwargs, state, reconnect_timer): + stage = cls_type(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=mocker.MagicMock() + ) + mocker.spy(stage, "run_op") + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + stage.state = state + stage.reconnect_timer = reconnect_timer + return stage + + @pytest.fixture + def event(self, arbitrary_event): + return arbitrary_event + + @pytest.mark.it("Sends the event up the pipeline") + def test_sends_up(self, mocker, stage, event): + stage.handle_pipeline_event(event) + + assert stage.send_event_up.call_count == 1 + assert stage.send_event_up.call_args == mocker.call(event) + + @pytest.mark.it("Does not change the state") + def test_state_unchanged(self, stage, event): + original_state = stage.state + stage.handle_pipeline_event(event) + assert stage.state is original_state + + @pytest.mark.it("Does not cancel, clear or set a reconnect timer") + def test_timer_untouched(self, mocker, stage, event, mock_timer): + original_timer = stage.reconnect_timer + stage.handle_pipeline_event(event) + + assert stage.reconnect_timer is original_timer + if stage.reconnect_timer: + assert stage.reconnect_timer.cancel.call_count == 0 + assert mock_timer.call_count == 0 + + +@pytest.mark.describe("ReconnectStage - OCCURANCE: Reconnect Timer expires") +class TestReconnectStageReconnectTimerExpires(ReconnectStageTestConfig): + @pytest.fixture() + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=mocker.MagicMock() + ) + mocker.spy(stage, "run_op") + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + + return stage + + @pytest.fixture + def trigger_stage_retry_timer_completion(self, stage, mock_timer): + # The stage must be connected in order to set a reconnect timer + stage.pipeline_root.connected = True + + # Send a DisconnectedEvent to the stage in order to set up the timer + stage.handle_pipeline_event(pipeline_events_base.DisconnectedEvent()) + + # Get timer completion callback + assert mock_timer.call_count == 1 + timer_callback = mock_timer.call_args[0][1] + return timer_callback + + @pytest.mark.parametrize("state", [pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT]) + @pytest.mark.it("Creates a new ConnectOperation and sends it down the pipeline") + def test_pipeline_disconnected( + self, mocker, stage, trigger_stage_retry_timer_completion, state + ): + stage.state = state + mock_connect_op = mocker.patch.object(pipeline_ops_base, "ConnectOperation") + + trigger_stage_retry_timer_completion() + + assert mock_connect_op.call_count == 1 + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(mock_connect_op.return_value) + + @pytest.mark.parametrize( + "state", + [ + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + ], + ) + @pytest.mark.it("Does not create a new ConnectOperation and send it down the pipeline") + def test_pipeline_connected(self, mocker, stage, trigger_stage_retry_timer_completion, state): + stage.state = state + mock_connect_op = mocker.patch.object(pipeline_ops_base, "ConnectOperation") + + trigger_stage_retry_timer_completion() + + assert mock_connect_op.call_count == 0 + assert stage.send_op_down.call_count == 0 + + @pytest.mark.parametrize( + "state", + [ + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT, + ], + ) + @pytest.mark.it("Sets self.reconnect_timer to None") + def test_sets_reconnect_timer_to_none( + self, mocker, stage, trigger_stage_retry_timer_completion, state + ): + stage.state = state + trigger_stage_retry_timer_completion() + assert stage.reconnect_timer is None + + @pytest.mark.parametrize("state", [pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT]) + @pytest.mark.it("Changes the state to CONNECTED_OR_DISCONNECTED") + def test_changes_state(self, mocker, stage, trigger_stage_retry_timer_completion, state): + stage.state = state + trigger_stage_retry_timer_completion() + assert stage.state == pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED + + @pytest.mark.parametrize( + "state", + [ + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + ], + ) + @pytest.mark.it("Does not change the state") + def test_does_not_change_state( + self, mocker, stage, trigger_stage_retry_timer_completion, state + ): + stage.state = state + trigger_stage_retry_timer_completion() + assert stage.state == state + + +@pytest.mark.describe( + "ReconnectStage - OCCURANCE: ConnectOperation that was created in order to reconnect is completed" +) +class TestReconnectStageConnectOperationForReconnectIsCompleted(ReconnectStageTestConfig): + @pytest.fixture( + params=[ + pipeline_exceptions.OperationCancelled, + pipeline_exceptions.PipelineTimeoutError, + pipeline_exceptions.OperationError, + transport_exceptions.ConnectionFailedError, + transport_exceptions.ConnectionDroppedError, + ] + ) + def transient_connect_exception(self, request): + return request.param() + + @pytest.fixture() + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=mocker.MagicMock() + ) + mocker.spy(stage, "run_op") + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + + return stage + + @pytest.fixture( + params=[ + pytest.param(True, id="First connect attempt"), + pytest.param(False, id="Second connect attempt"), + ] + ) + def connect_op(self, stage, request, mocker, mock_timer): + first_connect_attempt = request.param + + if first_connect_attempt: + stage.run_op(pipeline_ops_base.ConnectOperation(callback=mocker.MagicMock())) + else: + # The stage must be connected and virtually connected in order to set a reconnect timer + stage.pipeline_root.connected = True + + # Send a DisconnectedEvent to the stage in order to set up the timer + stage.handle_pipeline_event(pipeline_events_base.DisconnectedEvent()) + + # Get timer completion callback + assert mock_timer.call_count == 1 + timer_callback = mock_timer.call_args[0][1] + mock_timer.reset_mock() + + # Force trigger the reconnect timer completion in order to trigger a reconnect + timer_callback() + + # Get the connect operation sent down as part of the reconnect + assert stage.send_op_down.call_count == 1 + connect_op = stage.send_op_down.call_args[0][0] + assert isinstance(connect_op, pipeline_ops_base.ConnectOperation) + return connect_op + + @pytest.fixture( + params=[ + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT, + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + ] + ) + def all_states(self, request): + return request.param + + @pytest.fixture + def fake_waiting_connect_ops(self, mocker): + op1 = ArbitraryOperation(callback=mocker.MagicMock()) + op1.original_callback = op1.callback_stack[0] + op2 = ArbitraryOperation(callback=mocker.MagicMock()) + op2.original_callback = op2.callback_stack[0] + return list([op1, op2]) + + @pytest.mark.it("Sets the state to CONNECTED_OR_DISCONNECTED if the connect succeeds") + def test_sets_state_on_success(self, stage, connect_op, all_states): + stage.state = all_states + connect_op.complete() + assert stage.state == pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED + + @pytest.mark.it("Clears and sets reconnect_timer to None if the connect succeeds") + def test_clears_reconnect_timer_on_success(self, stage, connect_op, all_states, mocker): + stage.state = all_states + reconnect_timer = mocker.MagicMock() + stage.reconnect_timer = reconnect_timer + connect_op.complete() + assert stage.reconnect_timer is None + assert reconnect_timer.cancel.call_count == 1 + + @pytest.mark.it("Does not create a new reconnect timer on success") + def test_does_not_create_new_reconnect_timer_on_success( + self, stage, connect_op, all_states, mock_timer + ): + stage.state = all_states + connect_op.complete() + assert stage.reconnect_timer is None + + @pytest.mark.it("Completes any waiting ops if the connect succeeds") + def test_completes_waiting_connect_ops( + self, stage, connect_op, all_states, fake_waiting_connect_ops, mocker + ): + stage.state = all_states + stage.waiting_connect_ops = list(fake_waiting_connect_ops) + connect_op.complete() + assert stage.waiting_connect_ops == [] + for op in fake_waiting_connect_ops: + assert op.callback_stack == [] + assert op.original_callback.call_count == 1 + assert op.original_callback.call_args == mocker.call(op=op, error=None) + + @pytest.mark.parametrize( + "state", + [ + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + ], + ) + @pytest.mark.it("Does not change state if the connection fails with an arbitrary error") + def test_does_not_change_state_on_arbitrary_exception( + self, stage, connect_op, state, arbitrary_exception + ): + stage.state = state + connect_op.complete(error=arbitrary_exception) + assert stage.state == state + + @pytest.mark.parametrize( + "state", + [ + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, ], ) @pytest.mark.it( - "Immediately completes a second op which was waiting for a first op that succeeded" + "Does not create a new reconnect timer if the connection fails with an arbitrary error" ) - def test_immediately_completes_second_op(self, stage, params, mocker): - first_connection_op = params["first_connection_op"](mocker.MagicMock()) - second_connection_op = params["second_connection_op"](mocker.MagicMock()) - stage.pipeline_root.connected = params["pre_connected_flag"] - - stage.run_op(first_connection_op) - stage.run_op(second_connection_op) - - # first_connection_op has been passed down. second_connection_op is waiting for first disconnect to complete. - assert stage.next.run_op.call_count == 1 - assert stage.next.run_op.call_args[0][0] == first_connection_op - - # complete first_connection_op - stage.pipeline_root.connected = params["mid_connect_flag"] - operation_flow.complete_op(stage=stage.next, op=first_connection_op) - - # second connect_op should be completed without having been passed down. - assert stage.next.run_op.call_count == 1 - assert_callback_succeeded(op=second_connection_op) - - -pipeline_stage_test.add_base_pipeline_stage_tests( - cls=pipeline_stages_base.CoordinateRequestAndResponseStage, - module=this_module, - all_ops=all_common_ops, - handled_ops=[pipeline_ops_base.SendIotRequestAndWaitForResponseOperation], - all_events=all_common_events, - handled_events=[pipeline_events_base.IotResponseEvent], - extra_initializer_defaults={"pending_responses": dict}, -) - - -fake_request_type = "__fake_request_type__" -fake_method = "__fake_method__" -fake_resource_location = "__fake_resource_location__" -fake_request_body = "__fake_request_body__" -fake_status_code = "__fake_status_code__" -fake_response_body = "__fake_response_body__" -fake_request_id = "__fake_request_id__" - - -def make_fake_request_and_response(mocker): - return pipeline_ops_base.SendIotRequestAndWaitForResponseOperation( - request_type=fake_request_type, - method=fake_method, - resource_location=fake_resource_location, - request_body=fake_request_body, - callback=mocker.MagicMock(), - ) - - -@pytest.mark.describe( - "CoordinateRequestAndResponse - .run_op() -- called with SendIotRequestAndWaitForResponseOperation" -) -class TestCoordinateRequestAndResponseSendIotRequestRunOp(object): - @pytest.fixture - def op(self, mocker): - return make_fake_request_and_response(mocker) - - @pytest.fixture - def stage(self, mocker): - return make_mock_stage(mocker, pipeline_stages_base.CoordinateRequestAndResponseStage) - - @pytest.mark.it( - "Sends an SendIotRequestOperation op to the next stage with the same parameters and a newly allocated request_id" - ) - def test_sends_op_and_validates_new_op(self, stage, op): - stage.run_op(op) - assert stage.next.run_op.call_count == 1 - new_op = stage.next.run_op.call_args[0][0] - assert isinstance(new_op, pipeline_ops_base.SendIotRequestOperation) - assert new_op.request_type == op.request_type - assert new_op.method == op.method - assert new_op.resource_location == op.resource_location - assert new_op.request_body == op.request_body - assert new_op.request_id - - @pytest.mark.it("Does not complete the SendIotRequestAndwaitForResponse op") - def test_sends_op_and_verifies_no_response(self, stage, op): - stage.run_op(op) - assert op.callback.call_count == 0 - - @pytest.mark.it("Fails SendIotRequestAndWaitForResponseOperation if there is no next stage") - def test_no_next_stage(self, stage, op): - stage.next = None - stage.run_op(op) - assert_callback_failed(op=op) - - @pytest.mark.it("Generates a new request_id for every operation") - def test_sends_two_ops_and_validates_request_id(self, stage, op, mocker): - op2 = make_fake_request_and_response(mocker) - stage.run_op(op) - stage.run_op(op2) - assert stage.next.run_op.call_count == 2 - new_op = stage.next.run_op.call_args_list[0][0][0] - new_op2 = stage.next.run_op.call_args_list[1][0][0] - assert new_op.request_id != new_op2.request_id - - @pytest.mark.it( - "Fails SendIotRequestAndWaitForResponseOperation if an Exception is raised in the SendIotRequestOperation op" - ) - def test_new_op_raises_exception(self, stage, op, mocker): - stage.next._execute_op = mocker.Mock(side_effect=Exception) - stage.run_op(op) - assert_callback_failed(op=op) - - @pytest.mark.it("Allows BaseExceptions rised on the SendIotRequestOperation op to propogate") - def test_new_op_raises_base_exception(self, stage, op, mocker): - stage.next._execute_op = mocker.Mock(side_effect=UnhandledException) - with pytest.raises(UnhandledException): - stage.run_op(op) - assert op.callback.call_count == 0 - - -@pytest.mark.describe( - "CoordinateRequestAndResponseStage - .handle_pipeline_event() -- called with IotResponseEvent" -) -class TestCoordinateRequestAndResponseSendIotRequestHandleEvent(object): - @pytest.fixture - def op(self, mocker): - return make_fake_request_and_response(mocker) - - @pytest.fixture - def stage(self, mocker): - return make_mock_stage(mocker, pipeline_stages_base.CoordinateRequestAndResponseStage) - - @pytest.fixture - def iot_request(self, stage, op): - stage.run_op(op) - return stage.next.run_op.call_args[0][0] - - @pytest.fixture - def iot_response(self, stage, iot_request): - return pipeline_events_base.IotResponseEvent( - request_id=iot_request.request_id, - status_code=fake_status_code, - response_body=fake_response_body, - ) - - @pytest.mark.it( - "Completes the SendIotRequestAndWaitForResponseOperation op with the matching request_id including response_body and status_code" - ) - def test_completes_op_with_matching_request_id(self, stage, op, iot_response): - operation_flow.pass_event_to_previous_stage(stage.next, iot_response) - assert_callback_succeeded(op=op) - assert op.status_code == iot_response.status_code - assert op.response_body == iot_response.response_body - - @pytest.mark.it( - "Calls the unhandled error handler if there is no previous stage when request_id matches" - ) - def test_matching_request_id_with_no_previous_stage( - self, stage, op, iot_response, unhandled_error_handler + def test_does_not_create_new_reconnect_timer_on_arbitrary_exception( + self, stage, connect_op, state, mock_timer, arbitrary_exception ): - stage.next.previous = None - operation_flow.pass_event_to_previous_stage(stage.next, iot_response) - assert unhandled_error_handler.call_count == 1 + stage.state = state + connect_op.complete(error=arbitrary_exception) + assert stage.reconnect_timer is None + assert mock_timer.call_count == 0 + + @pytest.mark.parametrize("state", [pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT]) + @pytest.mark.it( + "Clears and sets reconnect_timer to None if the connection fails with an arbitrary error" + ) + def test_clears_reconnect_timer_on_arbitrary_exception( + self, stage, connect_op, state, mocker, arbitrary_exception + ): + stage.state = state + reconnect_timer = mocker.MagicMock() + stage.reconnect_timer = reconnect_timer + connect_op.complete(error=arbitrary_exception) + assert stage.reconnect_timer is None + assert reconnect_timer.cancel.call_count == 1 + + @pytest.mark.parametrize("state", [pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT]) + @pytest.mark.it( + "Changes the state to CONNECTED_OR_DISCONNECTED if the connection fails with an arbitrary error" + ) + def test_changes_state_on_arbitrary_exception( + self, stage, connect_op, state, arbitrary_exception + ): + stage.state = state + connect_op.complete(error=arbitrary_exception) + assert stage.state == pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED @pytest.mark.it( - "Does nothing if an IotResponse with an identical request_id is received a second time" + "Completes all waiting ops with the arbitrary failure if the connection fails with an arbitrary error" ) - def test_ignores_duplicate_request_id(self, stage, op, iot_response, unhandled_error_handler): - operation_flow.pass_event_to_previous_stage(stage.next, iot_response) - assert_callback_succeeded(op=op) - op.callback.reset_mock() - - operation_flow.pass_event_to_previous_stage(stage.next, iot_response) - assert op.callback.call_count == 0 - assert unhandled_error_handler.call_count == 0 + def test_completes_waiting_connect_ops_on_arbitrary_exception( + self, stage, connect_op, all_states, fake_waiting_connect_ops, arbitrary_exception, mocker + ): + stage.state = all_states + stage.waiting_connect_ops = list(fake_waiting_connect_ops) + connect_op.complete(error=arbitrary_exception) + assert stage.waiting_connect_ops == [] + for op in fake_waiting_connect_ops: + assert op.callback_stack == [] + assert op.original_callback.call_count == 1 + assert op.original_callback.call_args == mocker.call(op=op, error=arbitrary_exception) + @pytest.mark.parametrize("state", [pipeline_stages_base.ReconnectState.NEVER_CONNECTED]) @pytest.mark.it( - "Does nothing if an IotResponse with a request_id is received for an operation that returned failure" + "Completes all waiting ops with the transient failure if the connection fails with a transient error" ) - def test_ignores_request_id_from_failure(self, stage, op, mocker, unhandled_error_handler): - stage.next._execute_op = mocker.MagicMock(side_effect=Exception) - stage.run_op(op) + def test_completes_all_waiting_connect_ops_on_transient_connect_exception( + self, + stage, + connect_op, + state, + fake_waiting_connect_ops, + transient_connect_exception, + mocker, + ): + stage.state = state + stage.waiting_connect_ops = list(fake_waiting_connect_ops) + connect_op.complete(error=transient_connect_exception) + assert stage.waiting_connect_ops == [] + for op in fake_waiting_connect_ops: + assert op.callback_stack == [] + assert op.original_callback.call_count == 1 + assert op.original_callback.call_args == mocker.call( + op=op, error=transient_connect_exception + ) - req = stage.next.run_op.call_args[0][0] - resp = pipeline_events_base.IotResponseEvent( - request_id=req.request_id, - status_code=fake_status_code, - response_body=fake_response_body, - ) + @pytest.mark.parametrize("state", [pipeline_stages_base.ReconnectState.NEVER_CONNECTED]) + @pytest.mark.it( + "Does not create a reconnect timer if the connection fails with a transient error" + ) + def test_does_not_create_reconnect_timer_on_transient_connect_exception( + self, stage, connect_op, state, mock_timer, transient_connect_exception + ): + stage.state = state + connect_op.complete(error=transient_connect_exception) + assert mock_timer.call_count == 0 - op.callback.reset_mock() - operation_flow.pass_event_to_previous_stage(stage.next, resp) - assert op.callback.call_count == 0 - assert unhandled_error_handler.call_count == 0 + @pytest.mark.parametrize( + "state", + [ + pipeline_stages_base.ReconnectState.NEVER_CONNECTED, + pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT, + ], + ) + @pytest.mark.it("Does not change state if the connection fails with a transient error") + def test_does_not_change_state_on_transient_connect_exception( + self, stage, connect_op, state, transient_connect_exception + ): + stage.state = state + connect_op.complete(error=transient_connect_exception) + assert stage.state == state - @pytest.mark.it("Does nothing if an IotResponse with an unknown request_id is received") - def test_ignores_unknown_request_id(self, stage, op, iot_response, unhandled_error_handler): - iot_response.request_id = fake_request_id - operation_flow.pass_event_to_previous_stage(stage.next, iot_response) - assert op.callback.call_count == 0 - assert unhandled_error_handler.call_count == 0 + @pytest.mark.parametrize( + "state", [pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED] + ) + @pytest.mark.it( + "Changes the state to WAITING_TO_RECONNECT if the connection fails with a transient error" + ) + def test_changes_state_on_transient_connect_exception( + self, stage, connect_op, state, transient_connect_exception + ): + stage.state = state + connect_op.complete(error=transient_connect_exception) + assert stage.state == pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT + + @pytest.mark.parametrize( + "state", + [ + pipeline_stages_base.ReconnectState.CONNECTED_OR_DISCONNECTED, + pipeline_stages_base.ReconnectState.WAITING_TO_RECONNECT, + ], + ) + @pytest.mark.it("Starts a new reconnect timer if the connection fails with a transient error") + def test_starts_reconnect_timer_on_transient_connect_exception( + self, stage, connect_op, state, transient_connect_exception, mock_timer + ): + stage.state = state + connect_op.complete(error=transient_connect_exception) + assert mock_timer.call_count == 1 + assert mock_timer.return_value.start.call_count == 1 diff --git a/azure-iot-device/tests/common/pipeline/test_pipeline_stages_http.py b/azure-iot-device/tests/common/pipeline/test_pipeline_stages_http.py new file mode 100644 index 000000000..f72812617 --- /dev/null +++ b/azure-iot-device/tests/common/pipeline/test_pipeline_stages_http.py @@ -0,0 +1,324 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import logging +import pytest +import sys +import six +from azure.iot.device.common import transport_exceptions, handle_exceptions +from azure.iot.device.common.pipeline import ( + pipeline_ops_base, + pipeline_stages_base, + pipeline_ops_http, + pipeline_stages_http, + pipeline_exceptions, + config, +) +from tests.common.pipeline.helpers import StageRunOpTestBase +from tests.common.pipeline import pipeline_stage_test + + +this_module = sys.modules[__name__] +logging.basicConfig(level=logging.DEBUG) +pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") + +################### +# COMMON FIXTURES # +################### + + +@pytest.fixture +def mock_transport(mocker): + return mocker.patch( + "azure.iot.device.common.pipeline.pipeline_stages_http.HTTPTransport", autospec=True + ) + + +# Not a fixture, but used in parametrization +def fake_callback(): + pass + + +######################## +# HTTP TRANSPORT STAGE # +######################## + + +class HTTPTransportStageTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_stages_http.HTTPTransportStage + + @pytest.fixture + def init_kwargs(self, mocker): + return {} + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + stage.send_op_down = mocker.MagicMock() + return stage + + +class HTTPTransportInstantiationTests(HTTPTransportStageTestConfig): + @pytest.mark.it("Initializes 'sas_token' attribute as None") + def test_sas_token(self, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + assert stage.sas_token is None + + @pytest.mark.it("Initializes 'transport' attribute as None") + def test_transport(self, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + assert stage.transport is None + + +pipeline_stage_test.add_base_pipeline_stage_tests( + test_module=this_module, + stage_class_under_test=pipeline_stages_http.HTTPTransportStage, + stage_test_config_class=HTTPTransportStageTestConfig, + extended_stage_instantiation_test_class=HTTPTransportInstantiationTests, +) + + +@pytest.mark.describe( + "HTTPTransportStage - .run_op() -- Called with SetHTTPConnectionArgsOperation" +) +class TestHTTPTransportStageRunOpCalledWithSetHTTPConnectionArgsOperation( + HTTPTransportStageTestConfig, StageRunOpTestBase +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_http.SetHTTPConnectionArgsOperation( + hostname="fake_hostname", + server_verification_cert="fake_server_verification_cert", + client_cert="fake_client_cert", + sas_token="fake_sas_token", + callback=mocker.MagicMock(), + ) + + @pytest.mark.it("Stores the sas_token operation in the 'sas_token' attribute of the stage") + def test_stores_data(self, stage, op, mocker, mock_transport): + stage.run_op(op) + assert stage.sas_token == op.sas_token + + # TODO: Should probably remove the requirement to set it on the root. This seems only needed by Horton + @pytest.mark.it( + "Creates an HTTPTransport object and sets it as the 'transport' attribute of the stage (and on the pipeline root)" + ) + def test_creates_transport(self, mocker, stage, op, mock_transport): + assert stage.transport is None + + stage.run_op(op) + + assert mock_transport.call_count == 1 + assert mock_transport.call_args == mocker.call( + hostname=op.hostname, + server_verification_cert=op.server_verification_cert, + x509_cert=op.client_cert, + ) + assert stage.transport is mock_transport.return_value + + @pytest.mark.it("Completes the operation with success, upon successful execution") + def test_succeeds(self, mocker, stage, op, mock_transport): + assert not op.completed + stage.run_op(op) + assert op.completed + + +# NOTE: The HTTPTransport object is not instantiated upon instantiation of the HTTPTransportStage. +# It is only added once the SetHTTPConnectionArgsOperation runs. +# The lifecycle of the HTTPTransportStage is as follows: +# 1. Instantiate the stage +# 2. Configure the stage with a SetHTTPConnectionArgsOperation +# 3. Run any other desired operations. +# +# This is to say, no operation should be running before SetHTTPConnectionArgsOperation. +# Thus, for the following tests, we will assume that the HTTPTransport has already been created, +# and as such, the stage fixture used will have already have one. +class HTTPTransportStageTestConfigComplex(HTTPTransportStageTestConfig): + # We add a pytest fixture parametrization between SAS an X509 since depending on the version of authentication, the op will be formatted differently. + @pytest.fixture(params=["SAS", "X509"]) + def stage(self, mocker, request, cls_type, init_kwargs): + mock_transport = mocker.patch( + "azure.iot.device.common.pipeline.pipeline_stages_http.HTTPTransport", autospec=True + ) + stage = cls_type(**init_kwargs) + stage.send_op_down = mocker.MagicMock() + # Set up the Transport on the stage + if request.param == "SAS": + op = pipeline_ops_http.SetHTTPConnectionArgsOperation( + hostname="fake_hostname", + server_verification_cert="fake_server_verification_cert", + sas_token="fake_sas_token", + callback=mocker.MagicMock(), + ) + else: + op = pipeline_ops_http.SetHTTPConnectionArgsOperation( + hostname="fake_hostname", + server_verification_cert="fake_server_verification_cert", + client_cert="fake_client_cert", + callback=mocker.MagicMock(), + ) + stage.run_op(op) + assert stage.transport is mock_transport.return_value + + return stage + + +@pytest.mark.describe("HTTPTransportStage - .run_op() -- Called with UpdateSasTokenOperation") +class TestHTTPTransportStageRunOpCalledWithUpdateSasTokenOperation( + HTTPTransportStageTestConfigComplex, StageRunOpTestBase +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_base.UpdateSasTokenOperation( + sas_token="new_fake_sas_token", callback=mocker.MagicMock() + ) + + @pytest.mark.it( + "Updates the 'sas_token' attribute to be the new value contained in the operation" + ) + def test_updates_token(self, stage, op): + assert stage.sas_token != op.sas_token + stage.run_op(op) + assert stage.sas_token == op.sas_token + + @pytest.mark.it("Completes the operation with success, upon successful execution") + def test_completes_op(self, stage, op): + assert not op.completed + stage.run_op(op) + assert op.completed + + +fake_method = "__fake_method__" +fake_path = "__fake_path__" +fake_headers = {"__fake_key__": "__fake_value__"} +fake_body = "__fake_body__" +fake_query_params = "__fake_query_params__" +fake_sas_token = "fake_sas_token" + + +@pytest.mark.describe( + "HTTPTransportStage - .run_op() -- Called with HTTPRequestAndResponseOperation" +) +class TestHTTPTransportStageRunOpCalledWithHTTPRequestAndResponseOperation( + HTTPTransportStageTestConfigComplex, StageRunOpTestBase +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_http.HTTPRequestAndResponseOperation( + method=fake_method, + path=fake_path, + headers=fake_headers, + body=fake_body, + query_params=fake_query_params, + callback=mocker.MagicMock(), + ) + + @pytest.mark.it("Sends an HTTP request via the HTTPTransport") + def test_http_request(self, mocker, stage, op): + stage.run_op(op) + # We add this because the default stage here contains a SAS Token. + fake_headers["Authorization"] = fake_sas_token + assert stage.transport.request.call_count == 1 + assert stage.transport.request.call_args == mocker.call( + method=fake_method, + path=fake_path, + headers=fake_headers, + body=fake_body, + query_params=fake_query_params, + callback=mocker.ANY, + ) + + @pytest.mark.it( + "Does not provide an Authorization header if the SAS Token is not set in the stage" + ) + def test_header_with_no_sas(self, mocker, stage, op): + # Manually overwriting stage with no SAS Token. + stage.sas_token = None + stage.run_op(op) + assert stage.transport.request.call_count == 1 + assert stage.transport.request.call_args == mocker.call( + method=fake_method, + path=fake_path, + headers=fake_headers, + body=fake_body, + query_params=fake_query_params, + callback=mocker.ANY, + ) + + @pytest.mark.it( + "Completes the operation unsucessfully if there is a failure requesting via the HTTPTransport, using the error raised by the HTTPTransport" + ) + def test_fails_operation(self, mocker, stage, op, arbitrary_exception): + stage.transport.request.side_effect = arbitrary_exception + stage.run_op(op) + assert op.completed + assert op.error is arbitrary_exception + + @pytest.mark.it( + "Completes the operation successfully if the request invokes the provided callback without an error" + ) + def test_completes_callback(self, mocker, stage, op): + def mock_request_callback(method, path, headers, query_params, body, callback): + fake_response = { + "resp": "__fake_response__".encode("utf-8"), + "status_code": "__fake_status_code__", + "reason": "__fake_reason__", + } + return callback(response=fake_response) + + # This is a way for us to mock the transport invoking the callback + stage.transport.request.side_effect = mock_request_callback + stage.run_op(op) + assert op.completed + + @pytest.mark.it( + "Adds a reason, status code, and response body to the op if request invokes the provided callback without an error" + ) + def test_formats_op_on_complete(self, mocker, stage, op): + def mock_request_callback(method, path, headers, query_params, body, callback): + fake_response = { + "resp": "__fake_response__".encode("utf-8"), + "status_code": "__fake_status_code__", + "reason": "__fake_reason__", + } + return callback(response=fake_response) + + # This is a way for us to mock the transport invoking the callback + stage.transport.request.side_effect = mock_request_callback + stage.run_op(op) + assert op.reason == "__fake_reason__" + assert op.response_body == "__fake_response__".encode("utf-8") + assert op.status_code == "__fake_status_code__" + + @pytest.mark.it( + "Completes the operation with an error if the request invokes the provided callback with the same error" + ) + def test_completes_callback_with_error(self, mocker, stage, op, arbitrary_exception): + def mock_on_response_complete(method, path, headers, query_params, body, callback): + return callback(error=arbitrary_exception) + + stage.transport.request.side_effect = mock_on_response_complete + stage.run_op(op) + assert op.completed + assert op.error is arbitrary_exception + + +# NOTE: This is not something that should ever happen in correct program flow +# There should be no operations that make it to the HTTPTransportStage that are not handled by it +@pytest.mark.describe("HTTPTransportStage - .run_op() -- called with arbitrary other operation") +class TestHTTPTransportStageRunOpCalledWithArbitraryOperation( + HTTPTransportStageTestConfigComplex, StageRunOpTestBase +): + @pytest.fixture + def op(self, arbitrary_op): + return arbitrary_op + + @pytest.mark.it("Sends the operation down") + def test_sends_op_down(self, mocker, stage, op): + stage.run_op(op) + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) diff --git a/azure-iot-device/tests/common/pipeline/test_pipeline_stages_mqtt.py b/azure-iot-device/tests/common/pipeline/test_pipeline_stages_mqtt.py index 90a991006..664ae7b45 100644 --- a/azure-iot-device/tests/common/pipeline/test_pipeline_stages_mqtt.py +++ b/azure-iot-device/tests/common/pipeline/test_pipeline_stages_mqtt.py @@ -7,696 +7,956 @@ import logging import pytest import sys import six -from azure.iot.device.common import errors, unhandled_exceptions +from azure.iot.device.common import transport_exceptions, handle_exceptions from azure.iot.device.common.pipeline import ( pipeline_ops_base, pipeline_stages_base, pipeline_ops_mqtt, + pipeline_events_base, pipeline_events_mqtt, pipeline_stages_mqtt, + pipeline_exceptions, + config, ) -from tests.common.pipeline.helpers import ( - assert_callback_failed, - assert_callback_succeeded, - all_common_ops, - all_common_events, - all_except, - UnhandledException, -) +from tests.common.pipeline.helpers import StageRunOpTestBase from tests.common.pipeline import pipeline_stage_test +this_module = sys.modules[__name__] logging.basicConfig(level=logging.DEBUG) +pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") + +################### +# COMMON FIXTURES # +################### -# This fixture makes it look like all test in this file tests are running -# inside the pipeline thread. Because this is an autouse fixture, we -# manually add it to the individual test.py files that need it. If, -# instead, we had added it to some conftest.py, it would be applied to -# every tests in every file and we don't want that. -@pytest.fixture(autouse=True) -def apply_fake_pipeline_thread(fake_pipeline_thread): +@pytest.fixture +def mock_transport(mocker): + return mocker.patch( + "azure.iot.device.common.pipeline.pipeline_stages_mqtt.MQTTTransport", autospec=True + ) + + +# Not a fixture, but used in parametrization +def fake_callback(): pass -this_module = sys.modules[__name__] +######################## +# MQTT TRANSPORT STAGE # +######################## -fake_client_id = "__fake_client_id__" -fake_hostname = "__fake_hostname__" -fake_username = "__fake_username__" -fake_ca_cert = "__fake_ca_cert__" -fake_sas_token = "__fake_sas_token__" -fake_topic = "__fake_topic__" -fake_payload = "__fake_payload__" -fake_certificate = "__fake_certificate__" -ops_handled_by_this_stage = [ - pipeline_ops_base.ConnectOperation, - pipeline_ops_base.DisconnectOperation, - pipeline_ops_base.ReconnectOperation, - pipeline_ops_base.UpdateSasTokenOperation, - pipeline_ops_mqtt.SetMQTTConnectionArgsOperation, - pipeline_ops_mqtt.MQTTPublishOperation, - pipeline_ops_mqtt.MQTTSubscribeOperation, - pipeline_ops_mqtt.MQTTUnsubscribeOperation, -] +class MQTTTransportStageTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_stages_mqtt.MQTTTransportStage + + @pytest.fixture + def init_kwargs(self, mocker): + return {} + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=mocker.MagicMock() + ) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + return stage + + +class MQTTTransportInstantiationTests(MQTTTransportStageTestConfig): + @pytest.mark.it("Initializes 'sas_token' attribute as None") + def test_sas_token(self, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + assert stage.sas_token is None + + @pytest.mark.it("Initializes 'transport' attribute as None") + def test_transport(self, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + assert stage.transport is None + + @pytest.mark.it("Initializes with no pending connection operation") + def test_pending_op(self, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + assert stage._pending_connection_op is None -events_handled_by_this_stage = [] -# TODO: Potentially refactor this out to package level class that can be inherited pipeline_stage_test.add_base_pipeline_stage_tests( - cls=pipeline_stages_mqtt.MQTTTransportStage, - module=this_module, - all_ops=all_common_ops, - handled_ops=ops_handled_by_this_stage, - all_events=all_common_events, - handled_events=events_handled_by_this_stage, - methods_that_enter_pipeline_thread=[ - "_on_mqtt_message_received", - "_on_mqtt_connected", - "_on_mqtt_connection_failure", - "_on_mqtt_disconnected", - ], + test_module=this_module, + stage_class_under_test=pipeline_stages_mqtt.MQTTTransportStage, + stage_test_config_class=MQTTTransportStageTestConfig, + extended_stage_instantiation_test_class=MQTTTransportInstantiationTests, ) -@pytest.fixture -def stage(mocker): - stage = pipeline_stages_mqtt.MQTTTransportStage() - root = pipeline_stages_base.PipelineRootStage() - - stage.previous = root - root.next = stage - stage.pipeline_root = root - - mocker.spy(root, "handle_pipeline_event") - mocker.spy(root, "on_connected") - mocker.spy(root, "on_disconnected") - - mocker.spy(stage, "_on_mqtt_connected") - mocker.spy(stage, "_on_mqtt_connection_failure") - mocker.spy(stage, "_on_mqtt_disconnected") - - return stage - - -@pytest.fixture -def transport(mocker): - mocker.patch( - "azure.iot.device.common.pipeline.pipeline_stages_mqtt.MQTTTransport", autospec=True - ) - return pipeline_stages_mqtt.MQTTTransport - - -@pytest.fixture -def op_set_connection_args(mocker): - return pipeline_ops_mqtt.SetMQTTConnectionArgsOperation( - client_id=fake_client_id, - hostname=fake_hostname, - username=fake_username, - ca_cert=fake_ca_cert, - client_cert=fake_certificate, - sas_token=fake_sas_token, - callback=mocker.MagicMock(), - ) - - -@pytest.fixture -def op_connect(mocker): - return pipeline_ops_base.ConnectOperation(callback=mocker.MagicMock()) - - -@pytest.fixture -def op_reconnect(mocker): - return pipeline_ops_base.ReconnectOperation(callback=mocker.MagicMock()) - - -@pytest.fixture -def op_disconnect(mocker): - return pipeline_ops_base.DisconnectOperation(callback=mocker.MagicMock()) - - -@pytest.fixture -def op_publish(mocker): - return pipeline_ops_mqtt.MQTTPublishOperation( - topic=fake_topic, payload=fake_payload, callback=mocker.MagicMock() - ) - - -@pytest.fixture -def op_subscribe(mocker): - return pipeline_ops_mqtt.MQTTSubscribeOperation(topic=fake_topic, callback=mocker.MagicMock()) - - -@pytest.fixture -def op_unsubscribe(mocker): - return pipeline_ops_mqtt.MQTTUnsubscribeOperation(topic=fake_topic, callback=mocker.MagicMock()) - - -@pytest.fixture -def create_transport(stage, transport, op_set_connection_args): - stage.run_op(op_set_connection_args) - - -# TODO: This should be a package level class inherited by all .run_op() tests in all stages -class RunOpTests(object): - @pytest.mark.it( - "Completes the operation with failure if an unexpected Exception is raised while executing the operation" - ) - def test_completes_operation_with_error(self, mocker, stage): - execution_exception = Exception() - mock_op = mocker.MagicMock() - stage._execute_op = mocker.MagicMock(side_effect=execution_exception) - - stage.run_op(mock_op) - assert mock_op.error is execution_exception - - @pytest.mark.it( - "Allows any BaseException that was raised during execution of the operation to propogate" - ) - def test_base_exception_propogates(self, mocker, stage): - execution_exception = BaseException() - mock_op = mocker.MagicMock() - stage._execute_op = mocker.MagicMock(side_effect=execution_exception) - - with pytest.raises(BaseException): - stage.run_op(mock_op) - - @pytest.mark.describe( - "MQTTTransportStage - .run_op() -- called with pipeline_ops_mqtt.SetMQTTConnectionArgsOperation" + "MQTTTransportStage - .run_op() -- Called with SetMQTTConnectionArgsOperation" ) -class TestMQTTProviderRunOpWithSetConnectionArgs(RunOpTests): - @pytest.mark.it("Creates an MQTTTransport object") - def test_creates_transport(self, stage, transport, op_set_connection_args): - stage.run_op(op_set_connection_args) - assert transport.call_count == 1 +class TestMQTTTransportStageRunOpCalledWithSetMQTTConnectionArgsOperation( + MQTTTransportStageTestConfig, StageRunOpTestBase +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_mqtt.SetMQTTConnectionArgsOperation( + client_id="fake_client_id", + hostname="fake_hostname", + username="fake_username", + server_verification_cert="fake_server_verification_cert", + client_cert="fake_client_cert", + sas_token="fake_sas_token", + callback=mocker.MagicMock(), + ) + + @pytest.mark.it("Stores the sas_token operation in the 'sas_token' attribute of the stage") + def test_stores_data(self, stage, op, mocker, mock_transport): + stage.run_op(op) + assert stage.sas_token == op.sas_token @pytest.mark.it( - "Initializes the MQTTTransport object with the passed client_id, hostname, username, ca_cert and x509_cert" + "Creates an MQTTTransport object and sets it as the 'transport' attribute of the stage" ) - def test_passes_right_params(self, stage, transport, mocker, op_set_connection_args): - stage.run_op(op_set_connection_args) - assert transport.call_args == mocker.call( - client_id=fake_client_id, - hostname=fake_hostname, - username=fake_username, - ca_cert=fake_ca_cert, - x509_cert=fake_certificate, - ) + @pytest.mark.parametrize( + "websockets", + [ + pytest.param(True, id="Pipeline configured for websockets"), + pytest.param(False, id="Pipeline NOT configured for websockets"), + ], + ) + @pytest.mark.parametrize( + "cipher", + [ + pytest.param("DHE-RSA-AES128-SHA", id="Pipeline configured for custom cipher"), + pytest.param( + "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", + id="Pipeline configured for multiple custom ciphers", + ), + pytest.param("", id="Pipeline NOT configured for custom cipher(s)"), + ], + ) + @pytest.mark.parametrize( + "proxy_options", + [ + pytest.param("FAKE-PROXY", id="Proxy present"), + pytest.param(None, id="Proxy None"), + pytest.param("", id="Proxy Absent"), + ], + ) + def test_creates_transport( + self, mocker, stage, op, mock_transport, websockets, cipher, proxy_options + ): + # Configure websockets & cipher + stage.pipeline_root.pipeline_configuration.websockets = websockets + stage.pipeline_root.pipeline_configuration.cipher = cipher + stage.pipeline_root.pipeline_configuration.proxy_options = proxy_options - @pytest.mark.it("Sets handlers on the transport") - def test_sets_parameters(self, stage, transport, mocker, op_set_connection_args): - stage.run_op(op_set_connection_args) - assert transport.return_value.on_mqtt_disconnected_handler == stage._on_mqtt_disconnected - assert transport.return_value.on_mqtt_connected_handler == stage._on_mqtt_connected - assert ( - transport.return_value.on_mqtt_connection_failure_handler - == stage._on_mqtt_connection_failure - ) - assert ( - transport.return_value.on_mqtt_message_received_handler - == stage._on_mqtt_message_received - ) + assert stage.transport is None - @pytest.mark.it("Sets the pending connection op tracker to None") - def test_pending_conn_op(self, stage, transport, op_set_connection_args): - stage.run_op(op_set_connection_args) + stage.run_op(op) + + assert mock_transport.call_count == 1 + assert mock_transport.call_args == mocker.call( + client_id=op.client_id, + hostname=op.hostname, + username=op.username, + server_verification_cert=op.server_verification_cert, + x509_cert=op.client_cert, + websockets=websockets, + cipher=cipher, + proxy_options=proxy_options, + ) + assert stage.transport is mock_transport.return_value + + @pytest.mark.it("Sets event handlers on the newly created MQTTTransport") + def test_sets_transport_handlers(self, mocker, stage, op, mock_transport): + stage.run_op(op) + + assert stage.transport.on_mqtt_disconnected_handler == stage._on_mqtt_disconnected + assert stage.transport.on_mqtt_connected_handler == stage._on_mqtt_connected + assert ( + stage.transport.on_mqtt_connection_failure_handler == stage._on_mqtt_connection_failure + ) + assert stage.transport.on_mqtt_message_received_handler == stage._on_mqtt_message_received + + # CT-TODO: does this even need to be happening in this stage? Shouldn't this be part of init? + @pytest.mark.it("Sets the stage's pending connection operation to None") + def test_pending_conn_op(self, stage, op, mock_transport): + stage.run_op(op) assert stage._pending_connection_op is None @pytest.mark.it("Completes the operation with success, upon successful execution") - def test_succeeds(self, stage, transport, op_set_connection_args): - stage.run_op(op_set_connection_args) - assert_callback_succeeded(op=op_set_connection_args) + def test_succeeds(self, mocker, stage, op, mock_transport): + assert not op.completed + stage.run_op(op) + assert op.completed -@pytest.mark.describe("MQTTTransportStage - .run_op() -- called with ConnectOperation") -class TestMQTTProviderExecuteOpWithConnect(RunOpTests): - @pytest.mark.it("Sets the ConnectOperation as the pending connection operation") - def test_sets_pending_operation(self, stage, create_transport, op_connect): - stage.run_op(op_connect) - assert stage._pending_connection_op is op_connect +# NOTE: The MQTTTransport object is not instantiated upon instantiation of the MQTTTransportStage. +# It is only added once the SetMQTTConnectionArgsOperation runs. +# The lifecycle of the MQTTTransportStage is as follows: +# 1. Instantiate the stage +# 2. Configure the stage with a SetMQTTConnectionArgsOperation +# 3. Run any other desired operations. +# +# This is to say, no operation should be running before SetMQTTConnectionArgsOperation. +# Thus, for the following tests, we will assume that the MQTTTransport has already been created, +# and as such, the stage fixture used will have already have one. +class MQTTTransportStageTestConfigComplex(MQTTTransportStageTestConfig): + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs, mock_transport): + stage = cls_type(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=mocker.MagicMock() + ) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + + # Set up the Transport on the stage + op = pipeline_ops_mqtt.SetMQTTConnectionArgsOperation( + client_id="fake_client_id", + hostname="fake_hostname", + username="fake_username", + server_verification_cert="fake_server_verification_cert", + client_cert="fake_client_cert", + sas_token="fake_sas_token", + callback=mocker.MagicMock(), + ) + stage.run_op(op) + assert stage.transport is mock_transport.return_value + + return stage + + +@pytest.mark.describe("MQTTTransportStage - .run_op() -- Called with UpdateSasTokenOperation") +class TestMQTTTransportStageRunOpCalledWithUpdateSasTokenOperation( + MQTTTransportStageTestConfigComplex, StageRunOpTestBase +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_base.UpdateSasTokenOperation( + sas_token="new_fake_sas_token", callback=mocker.MagicMock() + ) + + @pytest.mark.it( + "Updates the 'sas_token' attribute to be the new value contained in the operation" + ) + def test_updates_token(self, stage, op): + assert stage.sas_token != op.sas_token + stage.run_op(op) + assert stage.sas_token == op.sas_token + + @pytest.mark.it("Completes the operation with success, upon successful execution") + def test_complets_op(self, stage, op): + assert not op.completed + stage.run_op(op) + assert op.completed + + +@pytest.mark.describe("MQTTTransportStage - .run_op() -- Called with ConnectOperation") +class TestMQTTTransportStageRunOpCalledWithConnectOperation( + MQTTTransportStageTestConfigComplex, StageRunOpTestBase +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_base.ConnectOperation(callback=mocker.MagicMock()) + + @pytest.mark.it("Sets the operation as the stage's pending connection operation") + def test_sets_pending_operation(self, stage, op): + stage.run_op(op) + assert stage._pending_connection_op is op @pytest.mark.it("Cancels any already pending connection operation") @pytest.mark.parametrize( "pending_connection_op", [ - pytest.param(pipeline_ops_base.ConnectOperation(), id="Pending ConnectOperation"), - pytest.param(pipeline_ops_base.ReconnectOperation(), id="Pending ReconnectOperation"), - pytest.param(pipeline_ops_base.DisconnectOperation(), id="Pending DisconnectOperation"), + pytest.param( + pipeline_ops_base.ConnectOperation(callback=fake_callback), + id="Pending ConnectOperation", + ), + pytest.param( + pipeline_ops_base.ReauthorizeConnectionOperation(callback=fake_callback), + id="Pending ReauthorizeConnectOperation", + ), + pytest.param( + pipeline_ops_base.DisconnectOperation(callback=fake_callback), + id="Pending DisconnectOperation", + ), ], ) - def test_pending_operation_cancelled( - self, mocker, stage, create_transport, op_connect, pending_connection_op - ): - pending_connection_op.callback = mocker.MagicMock() + def test_pending_operation_cancelled(self, mocker, stage, op, pending_connection_op): + # Set up a pending op stage._pending_connection_op = pending_connection_op - stage.run_op(op_connect) + assert not pending_connection_op.completed - # Callback has been completed, with a PipelineError set indicating early cancellation - assert_callback_failed(op=pending_connection_op, error=errors.PipelineError) + # Run the connect op + stage.run_op(op) + + # Operation has been completed, with an OperationCancelled exception set indicating early cancellation + assert pending_connection_op.completed + assert type(pending_connection_op.error) is pipeline_exceptions.OperationCancelled # New operation is now the pending operation - assert stage._pending_connection_op is op_connect + assert stage._pending_connection_op is op - @pytest.mark.it("Does an MQTT connect via the MQTTTransport") - def test_mqtt_connect(self, mocker, stage, create_transport, op_connect): - stage.run_op(op_connect) + @pytest.mark.it("Performs an MQTT connect via the MQTTTransport") + def test_mqtt_connect(self, mocker, stage, op): + stage.run_op(op) assert stage.transport.connect.call_count == 1 assert stage.transport.connect.call_args == mocker.call(password=stage.sas_token) @pytest.mark.it( - "Fails the operation and resets the pending connection operation to None, if there is a failure connecting in the MQTTTransport" + "Completes the operation unsucessfully if there is a failure connecting via the MQTTTransport, using the error raised by the MQTTTransport" ) - def test_fails_operation(self, stage, create_transport, op_connect, fake_exception): - stage.transport.connect.side_effect = fake_exception - stage.run_op(op_connect) - assert_callback_failed(op=op_connect, error=fake_exception) - assert stage._pending_connection_op is None - - -@pytest.mark.describe("MQTTTransportStage - .run_op() -- called with ReconnectOperation") -class TestMQTTProviderExecuteOpWithReconnect(RunOpTests): - @pytest.mark.it("Sets the ReconnectOperation as the pending connection operation") - def test_sets_pending_operation(self, stage, create_transport, op_reconnect): - stage.run_op(op_reconnect) - assert stage._pending_connection_op is op_reconnect - - @pytest.mark.it("Cancels any already pending connection operation") - @pytest.mark.parametrize( - "pending_connection_op", - [ - pytest.param(pipeline_ops_base.ConnectOperation(), id="Pending ConnectOperation"), - pytest.param(pipeline_ops_base.ReconnectOperation(), id="Pending ReconnectOperation"), - pytest.param(pipeline_ops_base.DisconnectOperation(), id="Pending DisconnectOperation"), - ], - ) - def test_pending_operation_cancelled( - self, mocker, stage, create_transport, op_reconnect, pending_connection_op - ): - pending_connection_op.callback = mocker.MagicMock() - stage._pending_connection_op = pending_connection_op - stage.run_op(op_reconnect) - - # Callback has been completed, with a PipelineError set indicating early cancellation - assert_callback_failed(op=pending_connection_op, error=errors.PipelineError) - - # New operation is now the pending operation - assert stage._pending_connection_op is op_reconnect - - @pytest.mark.it("Does an MQTT reconnect via the MQTTTransport") - def test_mqtt_reconnect(self, mocker, stage, create_transport, op_reconnect): - stage.run_op(op_reconnect) - assert stage.transport.reconnect.call_count == 1 - assert stage.transport.reconnect.call_args == mocker.call(password=stage.sas_token) + def test_fails_operation(self, mocker, stage, op, arbitrary_exception): + stage.transport.connect.side_effect = arbitrary_exception + stage.run_op(op) + assert op.completed + assert op.error is arbitrary_exception @pytest.mark.it( - "Fails the operation and resets the pending connection operation to None, if there is a failure reconnecting in the MQTTTransport" + "Resets the stage's pending connection operation to None, if there is a failure connecting via the MQTTTransport" ) - def test_fails_operation(self, mocker, stage, create_transport, op_reconnect, fake_exception): - stage.transport.reconnect.side_effect = fake_exception - stage.run_op(op_reconnect) - assert_callback_failed(op=op_reconnect, error=fake_exception) + def test_clears_pending_op_on_failure(self, mocker, stage, op, arbitrary_exception): + stage.transport.connect.side_effect = arbitrary_exception + stage.run_op(op) assert stage._pending_connection_op is None -@pytest.mark.describe("MQTTTransportStage - .run_op() -- called with DisconnectOperation") -class TestMQTTProviderExecuteOpWithDisconnect(RunOpTests): - @pytest.mark.it("Sets the DisconnectOperation as the pending connection operation") - def test_sets_pending_operation(self, stage, create_transport, op_disconnect): - stage.run_op(op_disconnect) - assert stage._pending_connection_op is op_disconnect +@pytest.mark.describe( + "MQTTTransportStage - .run_op() -- Called with ReauthorizeConnectionOperation" +) +class TestMQTTTransportStageRunOpCalledWithReauthorizeConnectionOperation( + MQTTTransportStageTestConfigComplex, StageRunOpTestBase +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_base.ReauthorizeConnectionOperation(callback=mocker.MagicMock()) + + @pytest.mark.it("Sets the operation as the stage's pending connection operation") + def test_sets_pending_operation(self, stage, op): + stage.run_op(op) + assert stage._pending_connection_op is op @pytest.mark.it("Cancels any already pending connection operation") @pytest.mark.parametrize( "pending_connection_op", [ - pytest.param(pipeline_ops_base.ConnectOperation(), id="Pending ConnectOperation"), - pytest.param(pipeline_ops_base.ReconnectOperation(), id="Pending ReconnectOperation"), - pytest.param(pipeline_ops_base.DisconnectOperation(), id="Pending DisconnectOperation"), + pytest.param( + pipeline_ops_base.ConnectOperation(callback=fake_callback), + id="Pending ConnectOperation", + ), + pytest.param( + pipeline_ops_base.ReauthorizeConnectionOperation(callback=fake_callback), + id="Pending ReauthorizeConnectionOperation", + ), + pytest.param( + pipeline_ops_base.DisconnectOperation(callback=fake_callback), + id="Pending DisconnectOperation", + ), ], ) - def test_pending_operation_cancelled( - self, mocker, stage, create_transport, op_disconnect, pending_connection_op - ): - pending_connection_op.callback = mocker.MagicMock() + def test_pending_operation_cancelled(self, mocker, stage, op, pending_connection_op): + # Set up a pending op stage._pending_connection_op = pending_connection_op - stage.run_op(op_disconnect) + assert not pending_connection_op.completed - # Callback has been completed, with a PipelineError set indicating early cancellation - assert_callback_failed(op=pending_connection_op, error=errors.PipelineError) + # Run the connect op + stage.run_op(op) + + # Operation has been completed, with an OperationCancelled exception set indicating early cancellation + assert pending_connection_op.completed + assert type(pending_connection_op.error) is pipeline_exceptions.OperationCancelled # New operation is now the pending operation - assert stage._pending_connection_op is op_disconnect + assert stage._pending_connection_op is op - @pytest.mark.it("Does an MQTT disconnect via the MQTTTransport") - def test_mqtt_disconnect(self, mocker, stage, create_transport, op_disconnect): - stage.run_op(op_disconnect) + @pytest.mark.it("Performs an MQTT reconnect via the MQTTTransport") + def test_mqtt_connect(self, mocker, stage, op): + stage.run_op(op) + assert stage.transport.reauthorize_connection.call_count == 1 + assert stage.transport.reauthorize_connection.call_args == mocker.call( + password=stage.sas_token + ) + + @pytest.mark.it( + "Completes the operation unsucessfully if there is a failure reconnecting via the MQTTTransport, using the error raised by the MQTTTransport" + ) + def test_fails_operation(self, mocker, stage, op, arbitrary_exception): + stage.transport.reauthorize_connection.side_effect = arbitrary_exception + stage.run_op(op) + assert op.completed + assert op.error is arbitrary_exception + + @pytest.mark.it( + "Resets the stage's pending connection operation to None, if there is a failure reconnecting via the MQTTTransport" + ) + def test_clears_pending_op_on_failure(self, mocker, stage, op, arbitrary_exception): + stage.transport.reauthorize_connection.side_effect = arbitrary_exception + stage.run_op(op) + assert stage._pending_connection_op is None + + +@pytest.mark.describe("MQTTTransportStage - .run_op() -- Called with DisconnectOperation") +class TestMQTTTransportStageRunOpCalledWithDisconnectOperation( + MQTTTransportStageTestConfigComplex, StageRunOpTestBase +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_base.DisconnectOperation(callback=mocker.MagicMock()) + + @pytest.mark.it("Sets the operation as the stage's pending connection operation") + def test_sets_pending_operation(self, stage, op): + stage.run_op(op) + assert stage._pending_connection_op is op + + @pytest.mark.it("Cancels any already pending connection operation") + @pytest.mark.parametrize( + "pending_connection_op", + [ + pytest.param( + pipeline_ops_base.ConnectOperation(callback=fake_callback), + id="Pending ConnectOperation", + ), + pytest.param( + pipeline_ops_base.ReauthorizeConnectionOperation(callback=fake_callback), + id="Pending ReauthorizeConnectionOperation", + ), + pytest.param( + pipeline_ops_base.DisconnectOperation(callback=fake_callback), + id="Pending DisconnectOperation", + ), + ], + ) + def test_pending_operation_cancelled(self, mocker, stage, op, pending_connection_op): + # Set up a pending op + stage._pending_connection_op = pending_connection_op + assert not pending_connection_op.completed + + # Run the connect op + stage.run_op(op) + + # Operation has been completed, with an OperationCancelled exception set indicating early cancellation + assert pending_connection_op.completed + assert type(pending_connection_op.error) is pipeline_exceptions.OperationCancelled + + # New operation is now the pending operation + assert stage._pending_connection_op is op + + @pytest.mark.it("Performs an MQTT disconnect via the MQTTTransport") + def test_mqtt_connect(self, mocker, stage, op): + stage.run_op(op) assert stage.transport.disconnect.call_count == 1 assert stage.transport.disconnect.call_args == mocker.call() @pytest.mark.it( - "Fails the operation and resets the pending connection operation to None, if there is a failure disconnecting in the MQTTTransport" + "Completes the operation unsucessfully if there is a failure disconnecting via the MQTTTransport, using the error raised by the MQTTTransport" ) - def test_fails_operation(self, mocker, stage, create_transport, op_disconnect, fake_exception): - stage.transport.disconnect.side_effect = fake_exception - stage.run_op(op_disconnect) - assert_callback_failed(op=op_disconnect, error=fake_exception) + def test_fails_operation(self, mocker, stage, op, arbitrary_exception): + stage.transport.disconnect.side_effect = arbitrary_exception + stage.run_op(op) + assert op.completed + assert op.error is arbitrary_exception + + @pytest.mark.it( + "Resets the stage's pending connection operation to None, if there is a failure disconnecting via the MQTTTransport" + ) + def test_clears_pending_op_on_failure(self, mocker, stage, op, arbitrary_exception): + stage.transport.disconnect.side_effect = arbitrary_exception + stage.run_op(op) assert stage._pending_connection_op is None @pytest.mark.describe("MQTTTransportStage - .run_op() -- called with MQTTPublishOperation") -class TestMQTTProviderExecuteOpWithMQTTPublishOperation(RunOpTests): - @pytest.mark.it("Does an MQTT publish via the MQTTTransport") - def test_mqtt_publish(self, mocker, stage, create_transport, op_publish): - stage.run_op(op_publish) +class TestMQTTTransportStageRunOpCalledWithMQTTPublishOperation( + MQTTTransportStageTestConfigComplex, StageRunOpTestBase +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_mqtt.MQTTPublishOperation( + topic="fake_topic", payload="fake_payload", callback=mocker.MagicMock() + ) + + @pytest.mark.it("Performs an MQTT publish via the MQTTTransport") + def test_mqtt_publish(self, mocker, stage, op): + stage.run_op(op) assert stage.transport.publish.call_count == 1 assert stage.transport.publish.call_args == mocker.call( - topic=op_publish.topic, payload=op_publish.payload, callback=mocker.ANY + topic=op.topic, payload=op.payload, callback=mocker.ANY ) @pytest.mark.it( - "Completes the operation with success, upon successful completion of the MQTT publish" + "Sucessfully completes the operation, upon successful completion of the MQTT publish by the MQTTTransport" ) - def test_complete(self, mocker, stage, create_transport, op_publish): + def test_complete(self, mocker, stage, op): # Begin publish - stage.run_op(op_publish) + stage.run_op(op) + + assert not op.completed # Trigger publish completion stage.transport.publish.call_args[1]["callback"]() - assert_callback_succeeded(op=op_publish) + assert op.completed + assert op.error is None @pytest.mark.describe("MQTTTransportStage - .run_op() -- called with MQTTSubscribeOperation") -class TestMQTTProviderExecuteOpWithMQTTSubscribeOperation(RunOpTests): - @pytest.mark.it("Does an MQTT subscribe via the MQTTTransport") - def test_mqtt_publish(self, mocker, stage, create_transport, op_subscribe): - stage.run_op(op_subscribe) +class TestMQTTTransportStageRunOpCalledWithMQTTSubscribeOperation( + MQTTTransportStageTestConfigComplex, StageRunOpTestBase +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_mqtt.MQTTSubscribeOperation( + topic="fake_topic", callback=mocker.MagicMock() + ) + + @pytest.mark.it("Performs an MQTT subscribe via the MQTTTransport") + def test_mqtt_publish(self, mocker, stage, op): + stage.run_op(op) assert stage.transport.subscribe.call_count == 1 assert stage.transport.subscribe.call_args == mocker.call( - topic=op_subscribe.topic, callback=mocker.ANY + topic=op.topic, callback=mocker.ANY ) @pytest.mark.it( - "Completes the operation with success, upon successful completion of the MQTT subscribe" + "Sucessfully completes the operation, upon successful completion of the MQTT subscribe by the MQTTTransport" ) - def test_complete(self, mocker, stage, create_transport, op_subscribe): + def test_complete(self, mocker, stage, op): # Begin subscribe - stage.run_op(op_subscribe) + stage.run_op(op) + + assert not op.completed # Trigger subscribe completion stage.transport.subscribe.call_args[1]["callback"]() - assert_callback_succeeded(op=op_subscribe) + assert op.completed + assert op.error is None @pytest.mark.describe("MQTTTransportStage - .run_op() -- called with MQTTUnsubscribeOperation") -class TestMQTTProviderExecuteOpWithMQTTUnsubscribeOperation(RunOpTests): - @pytest.mark.it("Does an MQTT unsubscribe via the MQTTTransport") - def test_mqtt_publish(self, mocker, stage, create_transport, op_unsubscribe): - stage.run_op(op_unsubscribe) +class TestMQTTTransportStageRunOpCalledWithMQTTUnsubscribeOperation( + MQTTTransportStageTestConfigComplex, StageRunOpTestBase +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_mqtt.MQTTUnsubscribeOperation( + topic="fake_topic", callback=mocker.MagicMock() + ) + + @pytest.mark.it("Performs an MQTT unsubscribe via the MQTTTransport") + def test_mqtt_publish(self, mocker, stage, op): + stage.run_op(op) assert stage.transport.unsubscribe.call_count == 1 assert stage.transport.unsubscribe.call_args == mocker.call( - topic=op_unsubscribe.topic, callback=mocker.ANY + topic=op.topic, callback=mocker.ANY ) @pytest.mark.it( - "Completes the operation with success, upon successful completion of the MQTT unsubscribe" + "Successfully completes the operation upon successful completion of the MQTT unsubscribe by the MQTTTransport" ) - def test_complete(self, mocker, stage, create_transport, op_unsubscribe): + def test_complete(self, mocker, stage, op): # Begin unsubscribe - stage.run_op(op_unsubscribe) + stage.run_op(op) + + assert not op.completed # Trigger unsubscribe completion stage.transport.unsubscribe.call_args[1]["callback"]() - assert_callback_succeeded(op=op_unsubscribe) + assert op.completed + assert op.error is None -fake_sas_token = "__FAKE_SAS_TOKEN__" +# NOTE: This is not something that should ever happen in correct program flow +# There should be no operations that make it to the MQTTTransportStage that are not handled by it +@pytest.mark.describe("MQTTTransportStage - .run_op() -- called with arbitrary other operation") +class TestMQTTTransportStageRunOpCalledWithArbitraryOperation( + MQTTTransportStageTestConfigComplex, StageRunOpTestBase +): + @pytest.fixture + def op(self, arbitrary_op): + return arbitrary_op + + @pytest.mark.it("Sends the operation down") + def test_sends_op_down(self, mocker, stage, op): + stage.run_op(op) + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) -@pytest.mark.describe("MQTTTransportStage - .run_op() -- called with UpdateSasTokenOperation") -class TestMQTTProviderExecuteOpWithUpdateSasTokenoperation(RunOpTests): - @pytest.mark.it("Saves the token and completes immediately") - def test_mqtt_publish(self, mocker, stage, create_transport): - cb = mocker.MagicMock() - op_update_sas_token = pipeline_ops_base.UpdateSasTokenOperation( - sas_token=fake_sas_token, callback=cb - ) - stage.run_op(op_update_sas_token) - assert_callback_succeeded(op_update_sas_token) - assert stage.sas_token == fake_sas_token +@pytest.mark.describe("MQTTTransportStage - OCCURANCE: MQTT message received") +class TestMQTTTransportStageProtocolClientEvents(MQTTTransportStageTestConfigComplex): + @pytest.mark.it("Sends an IncomingMQTTMessageEvent event up the pipeline") + def test_incoming_message_handler(self, stage, mocker): + # Trigger MQTT message received + stage.transport.on_mqtt_message_received_handler(topic="fake_topic", payload="fake_payload") + assert stage.send_event_up.call_count == 1 + event = stage.send_event_up.call_args[0][0] + assert isinstance(event, pipeline_events_mqtt.IncomingMQTTMessageEvent) -@pytest.mark.describe("MQTTTransportStage - EVENT: MQTT message received") -class TestMQTTProviderProtocolClientEvents(object): - @pytest.mark.it("Fires an IncomingMQTTMessageEvent event for each MQTT message received") - def test_incoming_message_handler(self, stage, create_transport, mocker): + @pytest.mark.it("Passes topic and payload as part of the IncomingMQTTMessageEvent") + def test_verify_incoming_message_attributes(self, stage, mocker): + fake_topic = "fake_topic" + fake_payload = "fake_payload" + + # Trigger MQTT message received stage.transport.on_mqtt_message_received_handler(topic=fake_topic, payload=fake_payload) - assert stage.previous.handle_pipeline_event.call_count == 1 - call_arg = stage.previous.handle_pipeline_event.call_args[0][0] - assert isinstance(call_arg, pipeline_events_mqtt.IncomingMQTTMessageEvent) - @pytest.mark.it("Passes topic and payload as part of the IncomingMQTTMessageEvent event") - def test_verify_incoming_message_attributes(self, stage, create_transport, mocker): - stage.transport.on_mqtt_message_received_handler(topic=fake_topic, payload=fake_payload) - call_arg = stage.previous.handle_pipeline_event.call_args[0][0] - assert call_arg.payload == fake_payload - assert call_arg.topic == fake_topic + event = stage.send_event_up.call_args[0][0] + assert event.payload == fake_payload + assert event.topic == fake_topic -@pytest.mark.describe("MQTTTransportStage - EVENT: MQTT connected") -class TestMQTTProviderOnConnected(object): - @pytest.mark.it("Calls self.on_connected when the transport connected event fires") +@pytest.mark.describe("MQTTTransportStage - OCCURANCE: MQTT connected") +class TestMQTTTransportStageOnConnected(MQTTTransportStageTestConfigComplex): + @pytest.mark.it("Sends a ConnectedEvent up the pipeline") @pytest.mark.parametrize( "pending_connection_op", [ pytest.param(None, id="No pending operation"), - pytest.param(pipeline_ops_base.ConnectOperation(), id="Pending ConnectOperation"), - pytest.param(pipeline_ops_base.ReconnectOperation(), id="Pending ReconnectOperation"), - pytest.param(pipeline_ops_base.DisconnectOperation(), id="Pending DisconnectOperation"), + pytest.param(pipeline_ops_base.ConnectOperation(1), id="Pending ConnectOperation"), + pytest.param( + pipeline_ops_base.ReauthorizeConnectionOperation(1), + id="Pending ReauthorizeConnectionOperation", + ), + pytest.param( + pipeline_ops_base.DisconnectOperation(1), id="Pending DisconnectOperation" + ), ], ) - def test_connected_handler(self, stage, create_transport, pending_connection_op): + def test_sends_event_up(self, stage, pending_connection_op): stage._pending_connection_op = pending_connection_op - assert stage.previous.on_connected.call_count == 0 + # Trigger connect completion stage.transport.on_mqtt_connected_handler() - assert stage.previous.on_connected.call_count == 1 - @pytest.mark.it( - "Completes a pending ConnectOperation with success when the transport connected event fires" - ) - def test_completes_pending_connect_op(self, mocker, stage, create_transport): + assert stage.send_event_up.call_count == 1 + connect_event = stage.send_event_up.call_args[0][0] + assert isinstance(connect_event, pipeline_events_base.ConnectedEvent) + + @pytest.mark.it("Completes a pending ConnectOperation successfully") + def test_completes_pending_connect_op(self, mocker, stage): + # Set a pending connect operation op = pipeline_ops_base.ConnectOperation(callback=mocker.MagicMock()) stage.run_op(op) - assert op.callback.call_count == 0 + assert not op.completed assert stage._pending_connection_op is op + + # Trigger connect completion stage.transport.on_mqtt_connected_handler() - assert_callback_succeeded(op=op) + + # Connect operation completed successfully + assert op.completed + assert op.error is None assert stage._pending_connection_op is None - @pytest.mark.it( - "Completes a pending ReconnectOperation with success when the transport connected event fires" - ) - def test_completes_pending_reconnect_op(self, mocker, stage, create_transport): - op = pipeline_ops_base.ReconnectOperation(callback=mocker.MagicMock()) + @pytest.mark.it("Completes a pending ReauthorizeConnectionOperation successfully") + def test_completes_pending_reconnect_op(self, mocker, stage): + # Set a pending reconnect operation + op = pipeline_ops_base.ReauthorizeConnectionOperation(callback=mocker.MagicMock()) stage.run_op(op) - assert op.callback.call_count == 0 + assert not op.completed assert stage._pending_connection_op is op + + # Trigger connect completion stage.transport.on_mqtt_connected_handler() - assert_callback_succeeded(op=op) + + # Reconnect operation completed successfully + assert op.completed + assert op.error is None assert stage._pending_connection_op is None @pytest.mark.it( "Ignores a pending DisconnectOperation when the transport connected event fires" ) - def test_ignores_pending_disconnect_op(self, mocker, stage, create_transport): + def test_ignores_pending_disconnect_op(self, mocker, stage): + # Set a pending disconnect operation op = pipeline_ops_base.DisconnectOperation(callback=mocker.MagicMock()) stage.run_op(op) - assert op.callback.call_count == 0 + assert not op.completed assert stage._pending_connection_op is op + + # Trigger connect completion stage.transport.on_mqtt_connected_handler() - # handler did NOT trigger a callback - assert op.callback.call_count == 0 + + # Disconnect operation was NOT completed + assert not op.completed assert stage._pending_connection_op is op -@pytest.mark.describe("MQTTTarnsportStage - EVENT: MQTT connection failure") -class TestMQTTProviderOnConnectionFailure(object): - @pytest.mark.it( - "Does not call self.on_connected when the transport connection failure event fires" - ) +@pytest.mark.describe("MQTTTransportStage - OCCURANCE: MQTT connection failure") +class TestMQTTTransportStageOnConnectionFailure(MQTTTransportStageTestConfigComplex): + @pytest.mark.it("Does not send any events up the pipeline") @pytest.mark.parametrize( "pending_connection_op", [ pytest.param(None, id="No pending operation"), - pytest.param(pipeline_ops_base.ConnectOperation(), id="Pending ConnectOperation"), - pytest.param(pipeline_ops_base.ReconnectOperation(), id="Pending ReconnectOperation"), - pytest.param(pipeline_ops_base.DisconnectOperation(), id="Pending DisconnectOperation"), + pytest.param(pipeline_ops_base.ConnectOperation(1), id="Pending ConnectOperation"), + pytest.param( + pipeline_ops_base.ReauthorizeConnectionOperation(1), + id="Pending ReauthorizeConnectionOperation", + ), + pytest.param( + pipeline_ops_base.DisconnectOperation(1), id="Pending DisconnectOperation" + ), ], ) - def test_does_not_call_connected_handler( - self, stage, create_transport, fake_exception, pending_connection_op - ): - # This test is testing negative space - something the function does NOT do - rather than something it does + def test_does_not_send_event(self, mocker, stage, pending_connection_op, arbitrary_exception): stage._pending_connection_op = pending_connection_op - assert stage.previous.on_connected.call_count == 0 - stage.transport.on_mqtt_connection_failure_handler(fake_exception) - assert stage.previous.on_connected.call_count == 0 - @pytest.mark.it("Fails a pending ConnectOperation if the connection failure event fires") - def test_fails_pending_connect_op(self, mocker, stage, create_transport, fake_exception): + # Trigger connection failure with an arbitrary cause + stage.transport.on_mqtt_connection_failure_handler(arbitrary_exception) + + assert stage.send_event_up.call_count == 0 + + @pytest.mark.it( + "Completes a pending ConnectOperation unsuccessfully with the cause of connection failure as the error" + ) + def test_fails_pending_connect_op(self, mocker, stage, arbitrary_exception): + # Create a pending ConnectOperation op = pipeline_ops_base.ConnectOperation(callback=mocker.MagicMock()) stage.run_op(op) - assert op.callback.call_count == 0 + assert not op.completed assert stage._pending_connection_op is op - stage.transport.on_mqtt_connection_failure_handler(fake_exception) - assert_callback_failed(op=op, error=fake_exception) + + # Trigger connection failure with an arbitrary cause + stage.transport.on_mqtt_connection_failure_handler(arbitrary_exception) + + assert op.completed + assert op.error is arbitrary_exception assert stage._pending_connection_op is None - @pytest.mark.it("Fails a pending ReconnectOperation if the connection failure event fires") - def test_fails_pending_reconnect_op(self, mocker, stage, create_transport, fake_exception): - op = pipeline_ops_base.ReconnectOperation(callback=mocker.MagicMock()) + @pytest.mark.it( + "Completes a pending ReauthorizeConnectionOperation unsuccessfully with the cause of connection failure as the error" + ) + def test_fails_pending_reconnect_op(self, mocker, stage, arbitrary_exception): + # Create a pending ReauthorizeConnectionOperation + op = pipeline_ops_base.ReauthorizeConnectionOperation(callback=mocker.MagicMock()) stage.run_op(op) - assert op.callback.call_count == 0 + assert not op.completed assert stage._pending_connection_op is op - stage.transport.on_mqtt_connection_failure_handler(fake_exception) - assert_callback_failed(op=op, error=fake_exception) + + # Trigger connection failure with an arbitrary cause + stage.transport.on_mqtt_connection_failure_handler(arbitrary_exception) + + assert op.completed + assert op.error is arbitrary_exception assert stage._pending_connection_op is None - @pytest.mark.it("Ignores a pending DisconnectOperation if the connection failure event fires") - def test_ignores_pending_disconnect_op(self, mocker, stage, create_transport, fake_exception): + @pytest.mark.it("Ignores a pending DisconnectOperation, and does not complete it") + def test_ignores_pending_disconnect_op(self, mocker, stage, arbitrary_exception): + # Create a pending DisconnectOperation op = pipeline_ops_base.DisconnectOperation(callback=mocker.MagicMock()) stage.run_op(op) - assert op.callback.call_count == 0 + assert not op.completed assert stage._pending_connection_op is op - stage.transport.on_mqtt_connection_failure_handler(fake_exception) + + # Trigger connection failure with an arbitrary cause + stage.transport.on_mqtt_connection_failure_handler(arbitrary_exception) + # Assert nothing changed about the operation - assert op.callback.call_count == 0 + assert not op.completed assert stage._pending_connection_op is op @pytest.mark.it( - "Triggers the unhandled exception handler (with error cause) when the connection failure is unexpected" + "Triggers the swallowed exception handler (with error cause) when the connection failure is unexpected" ) @pytest.mark.parametrize( "pending_connection_op", [ pytest.param(None, id="No pending operation"), - pytest.param(pipeline_ops_base.DisconnectOperation(), id="Pending DisconnectOperation"), + pytest.param( + pipeline_ops_base.DisconnectOperation(callback=fake_callback), + id="Pending DisconnectOperation", + ), ], ) def test_unexpected_connection_failure( - self, mocker, stage, create_transport, fake_exception, pending_connection_op + self, mocker, stage, arbitrary_exception, pending_connection_op ): - # A connection failure is unexpected if there is not a pending Connect/Reconnect operation + # A connection failure is unexpected if there is not a pending Connect/ReauthorizeConnection operation # i.e. "Why did we get a connection failure? We weren't even trying to connect!" - mock_handler = mocker.patch.object( - unhandled_exceptions, "exception_caught_in_background_thread" - ) + mock_handler = mocker.patch.object(handle_exceptions, "swallow_unraised_exception") stage._pending_connection_operation = pending_connection_op - stage.transport.on_mqtt_connection_failure_handler(fake_exception) + + # Trigger connection failure with arbitrary cause + stage.transport.on_mqtt_connection_failure_handler(arbitrary_exception) + + # swallow exception handler has been called assert mock_handler.call_count == 1 - assert mock_handler.call_args[0][0] is fake_exception - - -@pytest.mark.describe("MQTTTransportStage - EVENT: MQTT disconnected") -class TestMQTTProviderOnDisconnected(object): - @pytest.mark.it("Calls self.on_disconnected when the transport disconnected event fires") - @pytest.mark.parametrize( - "cause", - [pytest.param(None, id="No error cause"), pytest.param(Exception(), id="With error cause")], - ) - @pytest.mark.parametrize( - "pending_connection_op", - [ - pytest.param(None, id="No pending operation"), - pytest.param(pipeline_ops_base.ConnectOperation(), id="Pending ConnectOperation"), - pytest.param(pipeline_ops_base.ReconnectOperation(), id="Pending ReconnectOperation"), - pytest.param(pipeline_ops_base.DisconnectOperation(), id="Pending DisconnectOperation"), - ], - ) - def test_disconnected_handler(self, stage, create_transport, pending_connection_op, cause): - stage._pending_connection_op = pending_connection_op - assert stage.previous.on_disconnected.call_count == 0 - stage.transport.on_mqtt_disconnected_handler(cause) - assert stage.previous.on_disconnected.call_count == 1 - - @pytest.mark.it( - "Completes a pending DisconnectOperation with success when the transport disconnected event fires without an error cause" - ) - def test_compltetes_pending_disconnect_op_when_no_error(self, mocker, stage, create_transport): - op = pipeline_ops_base.DisconnectOperation(callback=mocker.MagicMock()) - stage.run_op(op) - assert op.callback.call_count == 0 - assert stage._pending_connection_op is op - stage.transport.on_mqtt_disconnected_handler(None) - assert_callback_succeeded(op=op) - assert stage._pending_connection_op is None - - @pytest.mark.it( - "Completes a pending DisconnectOperation with failure (from ConnectionDroppedError) when the transport disconnected event fires with an error cause" - ) - def test_completes_pending_disconnect_op_with_error( - self, mocker, stage, create_transport, fake_exception - ): - op = pipeline_ops_base.DisconnectOperation(callback=mocker.MagicMock()) - stage.run_op(op) - assert op.callback.call_count == 0 - assert stage._pending_connection_op is op - stage.transport.on_mqtt_disconnected_handler(fake_exception) - assert_callback_failed(op=op, error=errors.ConnectionDroppedError) - assert stage._pending_connection_op is None - if six.PY3: - assert op.error.__cause__ is fake_exception - - @pytest.mark.it( - "Ignores an unrelated pending operation when the transport disconnected event fires" - ) - @pytest.mark.parametrize( - "cause", - [pytest.param(None, id="No error cause"), pytest.param(Exception(), id="With error cause")], - ) - @pytest.mark.parametrize( - "pending_connection_op", - [ - pytest.param(pipeline_ops_base.ConnectOperation(), id="Pending ConnectOperation"), - pytest.param(pipeline_ops_base.ReconnectOperation(), id="Pending ReconnectOperation"), - ], - ) - def test_ignores_unrelated_op( - self, mocker, stage, create_transport, pending_connection_op, cause - ): - stage._pending_connection_op = pending_connection_op - stage.transport.on_mqtt_disconnected_handler(cause) - # The unrelated pending operation is STILL the pending connection op - assert stage._pending_connection_op is pending_connection_op - - @pytest.mark.it( - "Triggers the unhandled exception handler (with ConnectionDroppedError) when the disconnect is unexpected" - ) - @pytest.mark.parametrize( - "cause", - [pytest.param(None, id="No error cause"), pytest.param(Exception(), id="With error cause")], - ) - @pytest.mark.parametrize( - "pending_connection_op", - [ - pytest.param(None, id="No pending operation"), - pytest.param(pipeline_ops_base.ConnectOperation(), id="Pending ConnectOperation"), - pytest.param(pipeline_ops_base.ReconnectOperation(), id="Pending ReconnectOperation"), - ], - ) - def test_unexpected_disconnect( - self, mocker, stage, create_transport, pending_connection_op, cause - ): - # A disconnect is unexpected when there is no pending operation, or a pending, non-Disconnect operation - mock_handler = mocker.patch.object( - unhandled_exceptions, "exception_caught_in_background_thread" + assert mock_handler.call_args == mocker.call( + arbitrary_exception, log_msg=mocker.ANY, log_lvl="info" ) + + +@pytest.mark.describe("MQTTTransportStage - OCCURANCE: MQTT disconnected") +class TestMQTTTransportStageOnDisconnected(MQTTTransportStageTestConfigComplex): + @pytest.fixture(params=[False, True], ids=["No error cause", "With error cause"]) + def cause(self, request, arbitrary_exception): + if request.param: + return arbitrary_exception + else: + return None + + @pytest.mark.it("Sends a DisconnectedEvent up the pipeline") + @pytest.mark.parametrize( + "pending_connection_op", + [ + pytest.param(None, id="No pending operation"), + pytest.param( + pipeline_ops_base.ConnectOperation(callback=fake_callback), + id="Pending ConnectOperation", + ), + pytest.param( + pipeline_ops_base.ReauthorizeConnectionOperation(callback=fake_callback), + id="Pending ReauthorizeConnectionOperation", + ), + pytest.param( + pipeline_ops_base.DisconnectOperation(callback=fake_callback), + id="Pending DisconnectOperation", + ), + ], + ) + def test_disconnected_handler(self, stage, pending_connection_op, cause): stage._pending_connection_op = pending_connection_op + assert stage.send_event_up.call_count == 0 + + # Trigger disconnect stage.transport.on_mqtt_disconnected_handler(cause) + + assert stage.send_event_up.call_count == 1 + event = stage.send_event_up.call_args[0][0] + assert isinstance(event, pipeline_events_base.DisconnectedEvent) + + @pytest.mark.it("Completes a pending DisconnectOperation successfully") + def test_compltetes_pending_disconnect_op(self, mocker, stage, cause): + # Create a pending DisconnectOperation + op = pipeline_ops_base.DisconnectOperation(callback=mocker.MagicMock()) + stage.run_op(op) + assert not op.completed + assert stage._pending_connection_op is op + + # Trigger disconnect + stage.transport.on_mqtt_disconnected_handler(cause) + + assert op.completed + assert op.error is None + + @pytest.mark.it( + "Swallows the exception that caused the disconnect, if there is a pending DisconnectOperation" + ) + def test_completes_pending_disconnect_op_with_error(self, mocker, stage, arbitrary_exception): + mock_swallow = mocker.patch.object(handle_exceptions, "swallow_unraised_exception") + + # Create a pending DisconnectOperation + op = pipeline_ops_base.DisconnectOperation(callback=mocker.MagicMock()) + stage.run_op(op) + assert not op.completed + assert stage._pending_connection_op is op + + # Trigger disconnect with arbitrary cause + stage.transport.on_mqtt_disconnected_handler(arbitrary_exception) + + # Exception swallower was called + assert mock_swallow.call_count == 1 + assert mock_swallow.call_args == mocker.call(arbitrary_exception, log_msg=mocker.ANY) + + @pytest.mark.it( + "Completes (unsuccessfully) a pending operation that is NOT a DisconnectOperation, with the cause of the disconnection set as the error, if there is a cause provided" + ) + @pytest.mark.parametrize( + "pending_connection_op", + [ + pytest.param( + pipeline_ops_base.ConnectOperation(callback=fake_callback), + id="Pending ConnectOperation", + ), + pytest.param( + pipeline_ops_base.ReauthorizeConnectionOperation(callback=fake_callback), + id="Pending ReauthorizeConnectionOperation", + ), + ], + ) + def test_comletes_with_cause_as_error_if_cause( + self, mocker, stage, pending_connection_op, arbitrary_exception + ): + stage._pending_connection_op = pending_connection_op + assert not pending_connection_op.completed + + # Trigger disconnect with arbitrary cause + stage.transport.on_mqtt_disconnected_handler(arbitrary_exception) + + assert pending_connection_op.completed + assert pending_connection_op.error is arbitrary_exception + + @pytest.mark.it( + "Completes (unsuccessfully) a pending operation that is NOT a DisconnectOperation with a ConnectionDroppedError if no cause is provided for the disconnection" + ) + @pytest.mark.parametrize( + "pending_connection_op", + [ + pytest.param( + pipeline_ops_base.ConnectOperation(callback=fake_callback), + id="Pending ConnectOperation", + ), + pytest.param( + pipeline_ops_base.ReauthorizeConnectionOperation(callback=fake_callback), + id="Pending ReauthorizeConnectionOperation", + ), + ], + ) + def test_comletes_with_connection_dropped_error_as_error_if_no_cause( + self, mocker, stage, pending_connection_op, arbitrary_exception + ): + stage._pending_connection_op = pending_connection_op + assert not pending_connection_op.completed + + # Trigger disconnect with no cause + stage.transport.on_mqtt_disconnected_handler() + + assert pending_connection_op.completed + assert isinstance(pending_connection_op.error, transport_exceptions.ConnectionDroppedError) + + @pytest.mark.it( + "Sends the error to the swallowed exception handler, if there is no pending operation when a disconnection occurs" + ) + def test_no_pending_op(self, mocker, stage, cause): + mock_handler = mocker.patch.object(handle_exceptions, "swallow_unraised_exception") + assert stage._pending_connection_op is None + + # Trigger disconnect + stage.transport.on_mqtt_disconnected_handler(cause) + assert mock_handler.call_count == 1 - assert isinstance(mock_handler.call_args[0][0], errors.ConnectionDroppedError) - if six.PY3: - assert mock_handler.call_args[0][0].__cause__ is cause + exception = mock_handler.call_args[0][0] + assert exception.__cause__ is cause + + @pytest.mark.it("Clears any pending operation on the stage") + @pytest.mark.parametrize( + "pending_connection_op", + [ + pytest.param(None, id="No pending operation"), + pytest.param( + pipeline_ops_base.ConnectOperation(callback=fake_callback), + id="Pending ConnectOperation", + ), + pytest.param( + pipeline_ops_base.ReauthorizeConnectionOperation(callback=fake_callback), + id="Pending ReauthorizeConnectionOperation", + ), + pytest.param( + pipeline_ops_base.DisconnectOperation(callback=fake_callback), + id="Pending DisconnectOperation", + ), + ], + ) + def test_clears_pending(self, mocker, stage, pending_connection_op, cause): + stage._pending_connection_op = pending_connection_op + + # Trigger disconnect + stage.transport.on_mqtt_disconnected_handler(cause) + + assert stage._pending_connection_op is None diff --git a/azure-iot-device/tests/common/test_async_adapter.py b/azure-iot-device/tests/common/test_async_adapter.py index 762d62e6b..800b85d42 100644 --- a/azure-iot-device/tests/common/test_async_adapter.py +++ b/azure-iot-device/tests/common/test_async_adapter.py @@ -119,27 +119,27 @@ class TestAwaitableCallback(object): @pytest.mark.it( "Causes an error to be set on the instance Future when an error parameter is passed to the call (without return_arg_name)" ) - async def test_raises_error_without_return_arg_name(self, fake_error): + async def test_raises_error_without_return_arg_name(self, arbitrary_exception): callback = async_adapter.AwaitableCallback() assert not callback.future.done() - callback(error=fake_error) + callback(error=arbitrary_exception) await asyncio.sleep(0.1) # wait to give time to complete the callback assert callback.future.done() - assert callback.future.exception() == fake_error - with pytest.raises(fake_error.__class__) as e_info: + assert callback.future.exception() == arbitrary_exception + with pytest.raises(arbitrary_exception.__class__) as e_info: await callback.completion() - assert e_info.value is fake_error + assert e_info.value is arbitrary_exception @pytest.mark.it( "Causes an error to be set on the instance Future when an error parameter is passed to the call (with return_arg_name)" ) - async def test_raises_error_with_return_arg_name(self, fake_error): + async def test_raises_error_with_return_arg_name(self, arbitrary_exception): callback = async_adapter.AwaitableCallback(return_arg_name="arg_name") assert not callback.future.done() - callback(error=fake_error) + callback(error=arbitrary_exception) await asyncio.sleep(0.1) # wait to give time to complete the callback assert callback.future.done() - assert callback.future.exception() == fake_error - with pytest.raises(fake_error.__class__) as e_info: + assert callback.future.exception() == arbitrary_exception + with pytest.raises(arbitrary_exception.__class__) as e_info: await callback.completion() - assert e_info.value is fake_error + assert e_info.value is arbitrary_exception diff --git a/azure-iot-device/tests/common/test_evented_callback.py b/azure-iot-device/tests/common/test_evented_callback.py index 73e36c3f0..fb4a0761d 100644 --- a/azure-iot-device/tests/common/test_evented_callback.py +++ b/azure-iot-device/tests/common/test_evented_callback.py @@ -64,27 +64,27 @@ class TestEventedCallback(object): @pytest.mark.it( "Causes an error to be raised from the wait call when an error parameter is passed to the call (without return_arg_name)" ) - def test_raises_error_without_return_arg_name(self, fake_error): + def test_raises_error_without_return_arg_name(self, arbitrary_exception): callback = EventedCallback() assert not callback.completion_event.isSet() - callback(error=fake_error) + callback(error=arbitrary_exception) sleep(0.1) # wait to give time to complete the callback assert callback.completion_event.isSet() - assert callback.exception == fake_error - with pytest.raises(fake_error.__class__) as e_info: + assert callback.exception == arbitrary_exception + with pytest.raises(arbitrary_exception.__class__) as e_info: callback.wait_for_completion() - assert e_info.value is fake_error + assert e_info.value is arbitrary_exception @pytest.mark.it( "Causes an error to be raised from the wait call when an error parameter is passed to the call (with return_arg_name)" ) - def test_raises_error_with_return_arg_name(self, fake_error): + def test_raises_error_with_return_arg_name(self, arbitrary_exception): callback = EventedCallback(return_arg_name="arg_name") assert not callback.completion_event.isSet() - callback(error=fake_error) + callback(error=arbitrary_exception) sleep(0.1) # wait to give time to complete the callback assert callback.completion_event.isSet() - assert callback.exception == fake_error - with pytest.raises(fake_error.__class__) as e_info: + assert callback.exception == arbitrary_exception + with pytest.raises(arbitrary_exception.__class__) as e_info: callback.wait_for_completion() - assert e_info.value is fake_error + assert e_info.value is arbitrary_exception diff --git a/azure-iot-device/tests/common/test_http_transport.py b/azure-iot-device/tests/common/test_http_transport.py new file mode 100644 index 000000000..48dff6523 --- /dev/null +++ b/azure-iot-device/tests/common/test_http_transport.py @@ -0,0 +1,284 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import azure.iot.device.common.http_transport as http_transport +from azure.iot.device.common.http_transport import HTTPTransport +from azure.iot.device.common.models.x509 import X509 +from six.moves import http_client +from azure.iot.device.common import transport_exceptions as errors +import pytest +import logging +import ssl +import threading + + +logging.basicConfig(level=logging.DEBUG) + +fake_hostname = "__fake_hostname__" +fake_method = "__fake_method__" +fake_path = "__fake_path__" + + +fake_server_verification_cert = "__fake_server_verification_cert__" +fake_x509_cert = "__fake_x509_certificate__" + + +@pytest.mark.describe("HTTPTransport - Instantiation") +class TestInstantiation(object): + @pytest.mark.it("Sets the proper required instance parameters") + def test_sets_required_parameters(self, mocker): + + mocker.patch.object(ssl, "SSLContext").return_value + mocker.patch.object(HTTPTransport, "_create_ssl_context").return_value + + http_transport_object = HTTPTransport( + hostname=fake_hostname, + server_verification_cert=fake_server_verification_cert, + x509_cert=fake_x509_cert, + ) + + assert http_transport_object._hostname == fake_hostname + assert http_transport_object._server_verification_cert == fake_server_verification_cert + assert http_transport_object._x509_cert == fake_x509_cert + + @pytest.mark.it( + "Configures TLS/SSL context to use TLS 1.2, require certificates and check hostname" + ) + def test_configures_tls_context(self, mocker): + mock_ssl_context_constructor = mocker.patch.object(ssl, "SSLContext") + mock_ssl_context = mock_ssl_context_constructor.return_value + + HTTPTransport(hostname=fake_hostname) + # Verify correctness of TLS/SSL Context + assert mock_ssl_context_constructor.call_count == 1 + assert mock_ssl_context_constructor.call_args == mocker.call(protocol=ssl.PROTOCOL_TLSv1_2) + assert mock_ssl_context.check_hostname is True + assert mock_ssl_context.verify_mode == ssl.CERT_REQUIRED + + @pytest.mark.it( + "Configures TLS/SSL context using default certificates if protocol wrapper not instantiated with a server verification certificate" + ) + def test_configures_tls_context_with_default_certs(self, mocker): + mock_ssl_context = mocker.patch.object(ssl, "SSLContext").return_value + + HTTPTransport(hostname=fake_hostname) + + assert mock_ssl_context.load_default_certs.call_count == 1 + assert mock_ssl_context.load_default_certs.call_args == mocker.call() + + @pytest.mark.it( + "Configures TLS/SSL context with provided server verification certificate if protocol wrapper instantiated with a server verification certificate" + ) + def test_configures_tls_context_with_server_verification_certs(self, mocker): + mock_ssl_context = mocker.patch.object(ssl, "SSLContext").return_value + + HTTPTransport( + hostname=fake_hostname, server_verification_cert=fake_server_verification_cert + ) + + assert mock_ssl_context.load_verify_locations.call_count == 1 + assert mock_ssl_context.load_verify_locations.call_args == mocker.call( + cadata=fake_server_verification_cert + ) + + @pytest.mark.it("Configures TLS/SSL context with client-provided-certificate-chain like x509") + def test_configures_tls_context_with_client_provided_certificate_chain(self, mocker): + fake_client_cert = X509("fantastic_beasts", "where_to_find_them", "alohomora") + mock_ssl_context_constructor = mocker.patch.object(ssl, "SSLContext") + mock_ssl_context = mock_ssl_context_constructor.return_value + + HTTPTransport(hostname=fake_hostname, x509_cert=fake_client_cert) + + assert mock_ssl_context.load_default_certs.call_count == 1 + assert mock_ssl_context.load_cert_chain.call_count == 1 + assert mock_ssl_context.load_cert_chain.call_args == mocker.call( + fake_client_cert.certificate_file, + fake_client_cert.key_file, + fake_client_cert.pass_phrase, + ) + + +class HTTPTransportTestConfig(object): + @pytest.fixture + def mock_http_client_constructor(self, mocker): + mocker.patch.object(ssl, "SSLContext").return_value + mocker.patch.object(HTTPTransport, "_create_ssl_context").return_value + mock_client_constructor = mocker.patch.object(http_client, "HTTPSConnection", autospec=True) + mock_client = mock_client_constructor.return_value + response_value = mock_client.getresponse.return_value + response_value.status = 1234 + response_value.reason = "__fake_reason__" + response_value.read.return_value = "__fake_response_read_value__" + return mock_client_constructor + + +@pytest.mark.describe("HTTPTransport - .request()") +class TestRequest(HTTPTransportTestConfig): + @pytest.mark.it("Generates a unique HTTP Client connection for each request") + def test_creates_http_connection_object(self, mocker, mock_http_client_constructor): + transport = HTTPTransport(hostname=fake_hostname) + # We call .result because we need to block for the Future to complete before moving on. + transport.request(fake_method, fake_path, mocker.MagicMock()).result() + assert mock_http_client_constructor.call_count == 1 + + transport.request(fake_method, fake_path, mocker.MagicMock()).result() + assert mock_http_client_constructor.call_count == 2 + + @pytest.mark.it("Uses the HTTP Transport SSL Context.") + def test_uses_ssl_context(self, mocker, mock_http_client_constructor): + transport = HTTPTransport(hostname=fake_hostname) + done = transport.request(fake_method, fake_path, mocker.MagicMock()) + done.result() + + assert mock_http_client_constructor.call_count == 1 + assert mock_http_client_constructor.call_args[1]["context"] == transport._ssl_context + + @pytest.mark.it("Formats the request URL from the stage's hostname, given a path.") + def test_formats_http_client_request_with_only_method_and_path( + self, mocker, mock_http_client_constructor + ): + transport = HTTPTransport(hostname=fake_hostname) + mock_http_client_request = mock_http_client_constructor.return_value.request + fake_method = "__fake_method__" + fake_path = "__fake_path__" + expected_url = "https://{}/{}".format(fake_hostname, fake_path) + done = transport.request(fake_method, fake_path, mocker.MagicMock()) + done.result() + + assert mock_http_client_constructor.call_count == 1 + assert mock_http_client_request.call_count == 1 + assert mock_http_client_request.call_args == mocker.call( + fake_method, expected_url, body="", headers={} + ) + + @pytest.mark.it( + "Formats the request URL from the stage's hostname, given a path and query parameters." + ) + def test_formats_http_client_request_with_method_path_and_query_params( + self, mocker, mock_http_client_constructor + ): + transport = HTTPTransport(hostname=fake_hostname) + mock_http_client_request = mock_http_client_constructor.return_value.request + fake_method = "__fake_method__" + fake_path = "__fake_path__" + fake_query_params = "__fake_query_params__" + expected_url = "https://{}/{}?{}".format(fake_hostname, fake_path, fake_query_params) + + done = transport.request( + fake_method, fake_path, mocker.MagicMock(), query_params=fake_query_params + ) + done.result() + assert mock_http_client_constructor.call_count == 1 + assert mock_http_client_request.call_count == 1 + assert mock_http_client_request.call_args == mocker.call( + fake_method, expected_url, body="", headers={} + ) + + @pytest.mark.it("Sends HTTP request via HTTP Client.") + @pytest.mark.parametrize( + "method, path, query_params, body, headers", + [ + pytest.param( + "__fake_method__", + "__fake_path__", + None, + None, + None, + id="Method and path (optional params set to None)", + ), + pytest.param( + "__fake_method__", + "__fake_path__", + "", + "", + "", + id="Method and path (optional params set to empty string)", + ), + pytest.param( + "__fake_method__", + "__fake_path__", + "__fake_query_params__", + None, + None, + id="Method, path, and query params (body and headers set to None)", + ), + pytest.param( + "__fake_method__", + "__fake_path__", + "__fake_query_params__", + "__fake_body__", + None, + id="Method, path, query_params, and body (headers set to None)", + ), + pytest.param( + "__fake_method__", + "__fake_path__", + "__fake_query_params__", + "__fake_body__", + "__fake_headers__", + id="All parameters provided", + ), + ], + ) + def test_calls_http_client_request_with_given_parameters( + self, mocker, mock_http_client_constructor, method, path, query_params, body, headers + ): + transport = HTTPTransport(hostname=fake_hostname) + mock_http_client_request = mock_http_client_constructor.return_value.request + if query_params: + expected_url = "https://{}/{}?{}".format(fake_hostname, path, query_params) + else: + expected_url = "https://{}/{}".format(fake_hostname, path) + + cb = mocker.MagicMock() + done = transport.request( + method, path, cb, body=body, headers=headers, query_params=query_params + ) + done.result() + assert mock_http_client_constructor.call_count == 1 + assert mock_http_client_request.call_count == 1 + assert mock_http_client_request.call_args[0][0] == method + assert mock_http_client_request.call_args[0][1] == expected_url + + actual_body = mock_http_client_request.call_args[1]["body"] + actual_headers = mock_http_client_request.call_args[1]["headers"] + + if body: + assert actual_body == body + else: + assert not bool(actual_body) + if headers: + assert actual_headers == headers + else: + assert not bool(actual_headers) + + @pytest.mark.it( + "Creates a response object with a status code, reason, and unformatted HTTP response and returns it via the callback." + ) + def test_returns_response_on_success(self, mocker, mock_http_client_constructor): + transport = HTTPTransport(hostname=fake_hostname) + cb = mocker.MagicMock() + done = transport.request(fake_method, fake_path, cb) + done.result() + + assert mock_http_client_constructor.call_count == 1 + assert cb.call_count == 1 + assert cb.call_args[1]["response"]["status_code"] == 1234 + assert cb.call_args[1]["response"]["reason"] == "__fake_reason__" + assert cb.call_args[1]["response"]["resp"] == "__fake_response_read_value__" + + @pytest.mark.it("Raises a ProtocolClientError if request raises an unexpected Exception") + def test_client_raises_unexpected_error( + self, mocker, mock_http_client_constructor, arbitrary_exception + ): + transport = HTTPTransport(hostname=fake_hostname) + mock_http_client_constructor.return_value.connect.side_effect = arbitrary_exception + cb = mocker.MagicMock() + done = transport.request(fake_method, fake_path, cb) + done.result() + error = cb.call_args[1]["error"] + assert isinstance(error, errors.ProtocolClientError) + assert error.__cause__ is arbitrary_exception diff --git a/azure-iot-device/tests/common/test_mqtt_transport.py b/azure-iot-device/tests/common/test_mqtt_transport.py index 378289731..22dcf5c64 100644 --- a/azure-iot-device/tests/common/test_mqtt_transport.py +++ b/azure-iot-device/tests/common/test_mqtt_transport.py @@ -4,14 +4,17 @@ # license information. # -------------------------------------------------------------------------- +import azure.iot.device.common.mqtt_transport as mqtt_transport from azure.iot.device.common.mqtt_transport import MQTTTransport, OperationManager from azure.iot.device.common.models.x509 import X509 -from azure.iot.device.common import errors +from azure.iot.device.common import transport_exceptions as errors import paho.mqtt.client as mqtt import ssl import copy import pytest import logging +import socket +import socks logging.basicConfig(level=logging.DEBUG) @@ -22,22 +25,19 @@ fake_username = fake_hostname + "/" + fake_device_id new_fake_password = "new fake password" fake_topic = "fake_topic" fake_payload = "Tarantallegra" +fake_cipher = "DHE-RSA-AES128-SHA" fake_qos = 1 fake_mid = 52 fake_rc = 0 -failed_conack_rc = mqtt.CONNACK_REFUSED_IDENTIFIER_REJECTED +fake_success_rc = 0 +fake_failed_rc = mqtt.MQTT_ERR_PROTOCOL +failed_connack_rc = mqtt.CONNACK_REFUSED_IDENTIFIER_REJECTED +fake_keepalive = 1234 +fake_thread = "__fake_thread__" -class DummyException(Exception): - pass - - -class UnhandledException(BaseException): - pass - - -# mapping of Paho conack rc codes to Error object classes -conack_return_codes = [ +# mapping of Paho connack rc codes to Error object classes +connack_return_codes = [ { "name": "CONNACK_REFUSED_PROTOCOL_VERSION", "rc": mqtt.CONNACK_REFUSED_PROTOCOL_VERSION, @@ -68,13 +68,13 @@ conack_return_codes = [ # mapping of Paho rc codes to Error object classes operation_return_codes = [ - {"name": "MQTT_ERR_NOMEM", "rc": mqtt.MQTT_ERR_NOMEM, "error": errors.ProtocolClientError}, + {"name": "MQTT_ERR_NOMEM", "rc": mqtt.MQTT_ERR_NOMEM, "error": errors.ConnectionDroppedError}, { "name": "MQTT_ERR_PROTOCOL", "rc": mqtt.MQTT_ERR_PROTOCOL, "error": errors.ProtocolClientError, }, - {"name": "MQTT_ERR_INVAL", "rc": mqtt.MQTT_ERR_INVAL, "error": errors.ArgumentError}, + {"name": "MQTT_ERR_INVAL", "rc": mqtt.MQTT_ERR_INVAL, "error": errors.ProtocolClientError}, { "name": "MQTT_ERR_NO_CONN", "rc": mqtt.MQTT_ERR_NO_CONN, @@ -154,6 +154,32 @@ class TestInstantiation(object): client_id=fake_device_id, clean_session=False, protocol=mqtt.MQTTv311 ) + @pytest.mark.it( + "Creates an instance of the Paho MQTT Client using Websockets when websockets parameter is True" + ) + def test_configures_mqtt_websockets(self, mocker): + mock_mqtt_client_constructor = mocker.patch.object(mqtt, "Client") + mock_mqtt_client = mock_mqtt_client_constructor.return_value + + MQTTTransport( + client_id=fake_device_id, + hostname=fake_hostname, + username=fake_username, + websockets=True, + ) + + assert mock_mqtt_client_constructor.call_count == 1 + assert mock_mqtt_client_constructor.call_args == mocker.call( + client_id=fake_device_id, + clean_session=False, + protocol=mqtt.MQTTv311, + transport="websockets", + ) + + # Verify websockets options have been set + assert mock_mqtt_client.ws_set_options.call_count == 1 + assert mock_mqtt_client.ws_set_options.call_args == mocker.call(path="/$iothub/websocket") + @pytest.mark.it( "Configures TLS/SSL context to use TLS 1.2, require certificates and check hostname" ) @@ -175,7 +201,7 @@ class TestInstantiation(object): assert mock_mqtt_client.tls_set_context.call_args == mocker.call(context=mock_ssl_context) @pytest.mark.it( - "Configures TLS/SSL context using default certificates if protocol wrapper not instantiated with a CA certificate" + "Configures TLS/SSL context using default certificates if protocol wrapper not instantiated with a server verification certificate" ) def test_configures_tls_context_with_default_certs(self, mocker, mock_mqtt_client): mock_ssl_context_constructor = mocker.patch.object(ssl, "SSLContext") @@ -187,22 +213,41 @@ class TestInstantiation(object): assert mock_ssl_context.load_default_certs.call_args == mocker.call() @pytest.mark.it( - "Configures TLS/SSL context with provided CA certificates if protocol wrapper instantiated with a CA certificate" + "Configures TLS/SSL context with provided server verification certificate if protocol wrapper instantiated with a server verification certificate" ) - def test_configures_tls_context_with_ca_certs(self, mocker, mock_mqtt_client): + def test_configures_tls_context_with_server_verification_certs(self, mocker, mock_mqtt_client): mock_ssl_context_constructor = mocker.patch.object(ssl, "SSLContext") mock_ssl_context = mock_ssl_context_constructor.return_value - ca_cert = "dummy_certificate" + server_verification_cert = "dummy_certificate" MQTTTransport( client_id=fake_device_id, hostname=fake_hostname, username=fake_username, - ca_cert=ca_cert, + server_verification_cert=server_verification_cert, ) assert mock_ssl_context.load_verify_locations.call_count == 1 - assert mock_ssl_context.load_verify_locations.call_args == mocker.call(cadata=ca_cert) + assert mock_ssl_context.load_verify_locations.call_args == mocker.call( + cadata=server_verification_cert + ) + + @pytest.mark.it( + "Configures TLS/SSL context with provided cipher if present during instantiation" + ) + def test_confgures_tls_context_with_cipher(self, mocker, mock_mqtt_client): + mock_ssl_context_constructor = mocker.patch.object(ssl, "SSLContext") + mock_ssl_context = mock_ssl_context_constructor.return_value + + MQTTTransport( + client_id=fake_device_id, + hostname=fake_hostname, + username=fake_username, + cipher=fake_cipher, + ) + + assert mock_ssl_context.set_ciphers.call_count == 1 + assert mock_ssl_context.set_ciphers.call_args == mocker.call(fake_cipher) @pytest.mark.it("Configures TLS/SSL context with client-provided-certificate-chain like x509") def test_configures_tls_context_with_client_provided_certificate_chain( @@ -291,11 +336,49 @@ class TestConnect(object): pytest.param(None, id="No password provided"), ], ) - def test_calls_paho_connect(self, mocker, mock_mqtt_client, transport, password): + @pytest.mark.parametrize( + "websockets,port", + [ + pytest.param(False, 8883, id="Not using websockets"), + pytest.param(True, 443, id="Using websockets"), + ], + ) + def test_calls_paho_connect( + self, mocker, mock_mqtt_client, transport, password, websockets, port + ): + + # We don't want to use a special fixture for websockets, so instead we are overriding the attribute below. + # However, we want to assert that this value is not undefined. For instance, the self._websockets convention private attribute + # could be changed to self._websockets1, and all our tests would still pass without the below assert statement. + assert transport._websockets is False + + transport._websockets = websockets + transport.connect(password) assert mock_mqtt_client.connect.call_count == 1 - assert mock_mqtt_client.connect.call_args == mocker.call(host=fake_hostname, port=8883) + assert mock_mqtt_client.connect.call_args == mocker.call( + host=fake_hostname, port=port, keepalive=mocker.ANY + ) + + @pytest.mark.it("Passes DEFAULT_KEEPALIVE to paho connect function") + @pytest.mark.parametrize( + "password", + [ + pytest.param(fake_password, id="Password provided"), + pytest.param(None, id="No password provided"), + ], + ) + def test_calls_paho_connect_with_keepalive(self, mocker, mock_mqtt_client, transport, password): + + mqtt_transport.DEFAULT_KEEPALIVE = fake_keepalive + + transport.connect(password) + + assert mock_mqtt_client.connect.call_count == 1 + assert mock_mqtt_client.connect.call_args == mocker.call( + host=fake_hostname, port=8883, keepalive=fake_keepalive + ) @pytest.mark.it("Starts MQTT Network Loop") @pytest.mark.parametrize( @@ -311,13 +394,88 @@ class TestConnect(object): assert mock_mqtt_client.loop_start.call_count == 1 assert mock_mqtt_client.loop_start.call_args == mocker.call() + @pytest.mark.it("Raises a ProtocolClientError if Paho connect raises an unexpected Exception") + def test_client_raises_unexpected_error( + self, mocker, mock_mqtt_client, transport, arbitrary_exception + ): + mock_mqtt_client.connect.side_effect = arbitrary_exception + with pytest.raises(errors.ProtocolClientError) as e_info: + transport.connect(fake_password) + assert e_info.value.__cause__ is arbitrary_exception + + @pytest.mark.it( + "Raises a ConnectionFailedError if Paho connect raises a socket.error Exception" + ) + def test_client_raises_socket_error( + self, mocker, mock_mqtt_client, transport, arbitrary_exception + ): + socket_error = socket.error() + mock_mqtt_client.connect.side_effect = socket_error + with pytest.raises(errors.ConnectionFailedError) as e_info: + transport.connect(fake_password) + assert e_info.value.__cause__ is socket_error + + @pytest.mark.it( + "Raises a TlsExchangeAuthError if Paho connect raises a socket.error of type SSLCertVerificationError Exception" + ) + def test_client_raises_socket_tls_auth_error( + self, mocker, mock_mqtt_client, transport, arbitrary_exception + ): + socket_error = ssl.SSLError("socket error", "CERTIFICATE_VERIFY_FAILED") + mock_mqtt_client.connect.side_effect = socket_error + with pytest.raises(errors.TlsExchangeAuthError) as e_info: + transport.connect(fake_password) + assert e_info.value.__cause__ is socket_error + print(e_info.value.__cause__.strerror) + + @pytest.mark.it( + "Raises a ProtocolProxyError if Paho connect raises a socket error or a ProxyError exception" + ) + def test_client_raises_socket_error_or_proxy_error_as_proxy_error( + self, mocker, mock_mqtt_client, transport, arbitrary_exception + ): + socks_error = socks.SOCKS5Error( + "it is a sock 5 error", socket_err="a general SOCKS5Error error" + ) + mock_mqtt_client.connect.side_effect = socks_error + with pytest.raises(errors.ProtocolProxyError) as e_info: + transport.connect(fake_password) + assert e_info.value.__cause__ is socks_error + print(e_info.value.__cause__.strerror) + + @pytest.mark.it( + "Raises a UnauthorizedError if Paho connect raises a socket error or a ProxyError exception" + ) + def test_client_raises_socket_error_or_proxy_error_as_unauthorized_error( + self, mocker, mock_mqtt_client, transport, arbitrary_exception + ): + socks_error = socks.SOCKS5AuthError( + "it is a sock 5 auth error", socket_err="an auth SOCKS5Error error" + ) + mock_mqtt_client.connect.side_effect = socks_error + with pytest.raises(errors.UnauthorizedError) as e_info: + transport.connect(fake_password) + assert e_info.value.__cause__ is socks_error + print(e_info.value.__cause__.strerror) + + @pytest.mark.it("Allows any BaseExceptions raised in Paho connect to propagate") + def test_client_raises_base_exception( + self, mock_mqtt_client, transport, arbitrary_base_exception + ): + mock_mqtt_client.connect.side_effect = arbitrary_base_exception + with pytest.raises(arbitrary_base_exception.__class__) as e_info: + transport.connect(fake_password) + assert e_info.value is arbitrary_base_exception + + # NOTE: this test tests for all possible return codes, even ones that shouldn't be + # possible on a connect operation. + @pytest.mark.it("Raises a custom Exception if Paho connect returns a failing rc code") @pytest.mark.parametrize( "error_params", operation_return_codes, ids=["{}->{}".format(x["name"], x["error"].__name__) for x in operation_return_codes], ) - @pytest.mark.it("Raises a custom Exception if connect returns a failing rc code") - def test_transport_returns_failing_rc_code( + def test_client_returns_failing_rc_code( self, mocker, mock_mqtt_client, transport, error_params ): mock_mqtt_client.connect.return_value = error_params["rc"] @@ -325,11 +483,11 @@ class TestConnect(object): transport.connect(fake_password) -@pytest.mark.describe("MQTTTransport - .reconnect()") -class TestReconnect(object): +@pytest.mark.describe("MQTTTransport - .reauthorize_connection()") +class TestReauthorizeConnection(object): @pytest.mark.it("Uses the stored username and provided password for Paho credentials") def test_use_provided_password(self, mocker, mock_mqtt_client, transport): - transport.reconnect(fake_password) + transport.reauthorize_connection(fake_password) assert mock_mqtt_client.username_pw_set.call_count == 1 assert mock_mqtt_client.username_pw_set.call_args == mocker.call( @@ -340,7 +498,7 @@ class TestReconnect(object): "Uses the stored username without a password for Paho credentials, if password is not provided" ) def test_use_no_password(self, mocker, mock_mqtt_client, transport): - transport.reconnect() + transport.reauthorize_connection() assert mock_mqtt_client.username_pw_set.call_count == 1 assert mock_mqtt_client.username_pw_set.call_args == mocker.call( @@ -356,26 +514,46 @@ class TestReconnect(object): ], ) def test_calls_paho_reconnect(self, mocker, mock_mqtt_client, transport, password): - transport.reconnect(password) + transport.reauthorize_connection(password) assert mock_mqtt_client.reconnect.call_count == 1 assert mock_mqtt_client.reconnect.call_args == mocker.call() + @pytest.mark.it("Raises a ProtocolClientError if Paho reconnect raises an unexpected Exception") + def test_client_raises_unexpected_error( + self, mocker, mock_mqtt_client, transport, arbitrary_exception + ): + mock_mqtt_client.reconnect.side_effect = arbitrary_exception + with pytest.raises(errors.ProtocolClientError) as e_info: + transport.reauthorize_connection(fake_password) + assert e_info.value.__cause__ is arbitrary_exception + + @pytest.mark.it("Allows any BaseExceptions raised in Paho reconnect to propagate") + def test_client_raises_base_exception( + self, mock_mqtt_client, transport, arbitrary_base_exception + ): + mock_mqtt_client.reconnect.side_effect = arbitrary_base_exception + with pytest.raises(arbitrary_base_exception.__class__) as e_info: + transport.reauthorize_connection(fake_password) + assert e_info.value is arbitrary_base_exception + + # NOTE: this test tests for all possible return codes, even ones that shouldn't be + # possible on a reconnect operation. + @pytest.mark.it("Raises a custom Exception if Paho reconnect returns a failing rc code") @pytest.mark.parametrize( "error_params", operation_return_codes, ids=["{}->{}".format(x["name"], x["error"].__name__) for x in operation_return_codes], ) - @pytest.mark.it("Raises a custom Exception if reconnect returns a failing rc code") - def test_transport_returns_failing_rc_code( + def test_client_returns_failing_rc_code( self, mocker, mock_mqtt_client, transport, error_params ): mock_mqtt_client.reconnect.return_value = error_params["rc"] with pytest.raises(error_params["error"]): - transport.reconnect(fake_password) + transport.reauthorize_connection(fake_password) -@pytest.mark.describe("MQTTTransport - EVENT: Connect Completed") +@pytest.mark.describe("MQTTTransport - OCCURANCE: Connect Completed") class TestEventConnectComplete(object): @pytest.mark.it( "Triggers on_mqtt_connected_handler event handler upon successful connect completion" @@ -405,8 +583,10 @@ class TestEventConnectComplete(object): # Not raising an exception == test passed @pytest.mark.it("Recovers from Exception in on_mqtt_connected_handler event handler") - def test_event_handler_callback_raises_exception(self, mocker, mock_mqtt_client, transport): - event_cb = mocker.MagicMock(side_effect=DummyException) + def test_event_handler_callback_raises_exception( + self, mocker, mock_mqtt_client, transport, arbitrary_exception + ): + event_cb = mocker.MagicMock(side_effect=arbitrary_exception) transport.on_mqtt_connected_handler = event_cb transport.connect(fake_password) @@ -419,24 +599,25 @@ class TestEventConnectComplete(object): "Allows any BaseExceptions raised in on_mqtt_connected_handler event handler to propagate" ) def test_event_handler_callback_raises_base_exception( - self, mocker, mock_mqtt_client, transport + self, mocker, mock_mqtt_client, transport, arbitrary_base_exception ): - event_cb = mocker.MagicMock(side_effect=UnhandledException) + event_cb = mocker.MagicMock(side_effect=arbitrary_base_exception) transport.on_mqtt_connected_handler = event_cb transport.connect(fake_password) - with pytest.raises(UnhandledException): + with pytest.raises(arbitrary_base_exception.__class__) as e_info: mock_mqtt_client.on_connect( client=mock_mqtt_client, userdata=None, flags=None, rc=fake_rc ) + assert e_info.value is arbitrary_base_exception -@pytest.mark.describe("MQTTTransport - EVENT: Connection Failure") +@pytest.mark.describe("MQTTTransport - OCCURANCE: Connection Failure") class TestEventConnectionFailure(object): @pytest.mark.parametrize( "error_params", - conack_return_codes, - ids=["{}->{}".format(x["name"], x["error"].__name__) for x in conack_return_codes], + connack_return_codes, + ids=["{}->{}".format(x["name"], x["error"].__name__) for x in connack_return_codes], ) @pytest.mark.it( "Triggers on_mqtt_connection_failure_handler event handler with custom Exception upon failed connect completion" @@ -468,20 +649,22 @@ class TestEventConnectionFailure(object): transport.connect(fake_password) mock_mqtt_client.on_connect( - client=mock_mqtt_client, userdata=None, flags=None, rc=failed_conack_rc + client=mock_mqtt_client, userdata=None, flags=None, rc=failed_connack_rc ) # No further asserts required - this is a test to show that it skips a callback. # Not raising an exception == test passed @pytest.mark.it("Recovers from Exception in on_mqtt_connection_failure_handler event handler") - def test_event_handler_callback_raises_exception(self, mocker, mock_mqtt_client, transport): - event_cb = mocker.MagicMock(side_effect=DummyException) + def test_event_handler_callback_raises_exception( + self, mocker, mock_mqtt_client, transport, arbitrary_exception + ): + event_cb = mocker.MagicMock(side_effect=arbitrary_exception) transport.on_mqtt_connection_failure_handler = event_cb transport.connect(fake_password) mock_mqtt_client.on_connect( - client=mock_mqtt_client, userdata=None, flags=None, rc=failed_conack_rc + client=mock_mqtt_client, userdata=None, flags=None, rc=failed_connack_rc ) # Callback was called, but exception did not propagate @@ -491,16 +674,17 @@ class TestEventConnectionFailure(object): "Allows any BaseExceptions raised in on_mqtt_connection_failure_handler event handler to propagate" ) def test_event_handler_callback_raises_base_exception( - self, mocker, mock_mqtt_client, transport + self, mocker, mock_mqtt_client, transport, arbitrary_base_exception ): - event_cb = mocker.MagicMock(side_effect=UnhandledException) + event_cb = mocker.MagicMock(side_effect=arbitrary_base_exception) transport.on_mqtt_connection_failure_handler = event_cb transport.connect(fake_password) - with pytest.raises(UnhandledException): + with pytest.raises(arbitrary_base_exception.__class__) as e_info: mock_mqtt_client.on_connect( - client=mock_mqtt_client, userdata=None, flags=None, rc=failed_conack_rc + client=mock_mqtt_client, userdata=None, flags=None, rc=failed_connack_rc ) + assert e_info.value is arbitrary_base_exception @pytest.mark.describe("MQTTTransport - .disconnect()") @@ -519,21 +703,64 @@ class TestDisconnect(object): assert mock_mqtt_client.loop_stop.call_count == 1 assert mock_mqtt_client.loop_stop.call_args == mocker.call() + @pytest.mark.it( + "Raises a ProtocolClientError if Paho disconnect raises an unexpected Exception" + ) + def test_client_raises_unexpected_error( + self, mocker, mock_mqtt_client, transport, arbitrary_exception + ): + mock_mqtt_client.disconnect.side_effect = arbitrary_exception + with pytest.raises(errors.ProtocolClientError) as e_info: + transport.disconnect() + assert e_info.value.__cause__ is arbitrary_exception + + @pytest.mark.it("Allows any BaseExceptions raised in Paho disconnect to propagate") + def test_client_raises_base_exception( + self, mock_mqtt_client, transport, arbitrary_base_exception + ): + mock_mqtt_client.disconnect.side_effect = arbitrary_base_exception + with pytest.raises(arbitrary_base_exception.__class__) as e_info: + transport.disconnect() + assert e_info.value is arbitrary_base_exception + + # NOTE: this test tests for most possible return codes, even ones that shouldn't be + # possible on a disconnect operation. The exception is codes that correspond to a + # ConnectionDroppedError, as that does not result in a failure for .disconnect() + @pytest.mark.it("Raises a custom Exception if Paho disconnect returns a failing rc code") @pytest.mark.parametrize( "error_params", - operation_return_codes, - ids=["{}->{}".format(x["name"], x["error"].__name__) for x in operation_return_codes], + [x for x in operation_return_codes if x["error"] is not errors.ConnectionDroppedError], + ids=[ + "{}->{}".format(x["name"], x["error"].__name__) + for x in operation_return_codes + if x["error"] is not errors.ConnectionDroppedError + ], ) - @pytest.mark.it("Raises a custom Exception if disconnect returns a failing rc code") - def test_transport_returns_failing_rc_code( + def test_client_returns_failing_rc_code( self, mocker, mock_mqtt_client, transport, error_params ): mock_mqtt_client.disconnect.return_value = error_params["rc"] with pytest.raises(error_params["error"]): transport.disconnect() + # NOTE: Because .disconnect() intends to disconnect the connection, if the connection drops + # it isn't really a failure + @pytest.mark.it("Swallows failing rc codes related to dropped connections") + @pytest.mark.parametrize( + "error_params", + [x for x in operation_return_codes if x["error"] is errors.ConnectionDroppedError], + ids=[ + x["name"] for x in operation_return_codes if x["error"] is errors.ConnectionDroppedError + ], + ) + def test_client_drops_connection(self, mock_mqtt_client, transport, error_params): + mock_mqtt_client.disconnect.return_value = error_params["rc"] + transport.disconnect() -@pytest.mark.describe("MQTTTransport - EVENT: Disconnect Completed") + # No assert required - not throwing an error -> success! + + +@pytest.mark.describe("MQTTTransport - OCCURANCE: Disconnect Completed") class TestEventDisconnectCompleted(object): @pytest.mark.it( "Triggers on_mqtt_disconnected_handler event handler upon disconnect completion" @@ -594,8 +821,10 @@ class TestEventDisconnectCompleted(object): # Not raising an exception == test passed @pytest.mark.it("Recovers from Exception in on_mqtt_disconnected_handler event handler") - def test_event_handler_callback_raises_exception(self, mocker, mock_mqtt_client, transport): - event_cb = mocker.MagicMock(side_effect=DummyException) + def test_event_handler_callback_raises_exception( + self, mocker, mock_mqtt_client, transport, arbitrary_exception + ): + event_cb = mocker.MagicMock(side_effect=arbitrary_exception) transport.on_mqtt_disconnected_handler = event_cb transport.disconnect() @@ -608,14 +837,91 @@ class TestEventDisconnectCompleted(object): "Allows any BaseExceptions raised in on_mqtt_disconnected_handler event handler to propagate" ) def test_event_handler_callback_raises_base_exception( - self, mocker, mock_mqtt_client, transport + self, mocker, mock_mqtt_client, transport, arbitrary_base_exception ): - event_cb = mocker.MagicMock(side_effect=UnhandledException) + event_cb = mocker.MagicMock(side_effect=arbitrary_base_exception) transport.on_mqtt_disconnected_handler = event_cb transport.disconnect() - with pytest.raises(UnhandledException): + with pytest.raises(arbitrary_base_exception.__class__) as e_info: mock_mqtt_client.on_disconnect(client=mock_mqtt_client, userdata=None, rc=fake_rc) + assert e_info.value is arbitrary_base_exception + + @pytest.mark.it("Calls Paho's disconnect() method if cause is not None") + def test_calls_disconnect_with_cause(self, mock_mqtt_client, transport): + mock_mqtt_client.on_disconnect(client=mock_mqtt_client, userdata=None, rc=fake_failed_rc) + assert mock_mqtt_client.disconnect.call_count == 1 + + @pytest.mark.it("Does not call Paho's disconnect() method if cause is None") + def test_doesnt_call_disconnect_without_cause(self, mock_mqtt_client, transport): + mock_mqtt_client.on_disconnect(client=mock_mqtt_client, userdata=None, rc=fake_success_rc) + assert mock_mqtt_client.disconnect.call_count == 0 + + @pytest.mark.it("Calls Paho's loop_stop() if cause is not None") + def test_calls_loop_stop(self, mock_mqtt_client, transport): + mock_mqtt_client.on_disconnect(client=mock_mqtt_client, userdata=None, rc=fake_failed_rc) + assert mock_mqtt_client.loop_stop.call_count == 1 + + @pytest.mark.it("Does not calls Paho's loop_stop() if cause is None") + def test_does_not_call_loop_stop(self, mock_mqtt_client, transport): + mock_mqtt_client.on_disconnect(client=mock_mqtt_client, userdata=None, rc=fake_success_rc) + assert mock_mqtt_client.loop_stop.call_count == 0 + + @pytest.mark.it("Sets Paho's _thread to None if cause is not None") + def test_sets_thread_to_none(self, mock_mqtt_client, transport): + mock_mqtt_client._thread = fake_thread + mock_mqtt_client.on_disconnect(client=mock_mqtt_client, userdata=None, rc=fake_failed_rc) + assert mock_mqtt_client._thread is None + + @pytest.mark.it("Does not sets Paho's _thread to None if cause is None") + def test_does_not_set_thread_to_none(self, mock_mqtt_client, transport): + mock_mqtt_client._thread = fake_thread + mock_mqtt_client.on_disconnect(client=mock_mqtt_client, userdata=None, rc=fake_success_rc) + assert mock_mqtt_client._thread == fake_thread + + @pytest.mark.it("Allows any Exception raised by Paho's disconnect() to propagate") + def test_disconnect_raises_exception( + self, mock_mqtt_client, transport, mocker, arbitrary_exception + ): + mock_mqtt_client.disconnect = mocker.MagicMock(side_effect=arbitrary_exception) + with pytest.raises(type(arbitrary_exception)) as e_info: + mock_mqtt_client.on_disconnect( + client=mock_mqtt_client, userdata=None, rc=fake_failed_rc + ) + assert e_info.value is arbitrary_exception + + @pytest.mark.it("Allows any BaseException raised by Paho's disconnect() to propagate") + def test_disconnect_raises_base_exception( + self, mock_mqtt_client, transport, mocker, arbitrary_base_exception + ): + mock_mqtt_client.disconnect = mocker.MagicMock(side_effect=arbitrary_base_exception) + with pytest.raises(type(arbitrary_base_exception)) as e_info: + mock_mqtt_client.on_disconnect( + client=mock_mqtt_client, userdata=None, rc=fake_failed_rc + ) + assert e_info.value is arbitrary_base_exception + + @pytest.mark.it("Allows any Exception raised by Paho's loop_stop() to propagate") + def test_loop_stop_raises_exception( + self, mock_mqtt_client, transport, mocker, arbitrary_exception + ): + mock_mqtt_client.loop_stop = mocker.MagicMock(side_effect=arbitrary_exception) + with pytest.raises(type(arbitrary_exception)) as e_info: + mock_mqtt_client.on_disconnect( + client=mock_mqtt_client, userdata=None, rc=fake_failed_rc + ) + assert e_info.value is arbitrary_exception + + @pytest.mark.it("Allows any BaseException raised by Paho's loop_stop() to propagate") + def test_loop_stop_raises_base_exception( + self, mock_mqtt_client, transport, mocker, arbitrary_base_exception + ): + mock_mqtt_client.loop_stop = mocker.MagicMock(side_effect=arbitrary_base_exception) + with pytest.raises(type(arbitrary_base_exception)) as e_info: + mock_mqtt_client.on_disconnect( + client=mock_mqtt_client, userdata=None, rc=fake_failed_rc + ) + assert e_info.value is arbitrary_base_exception @pytest.mark.describe("MQTTTransport - .subscribe()") @@ -786,8 +1092,10 @@ class TestSubscribe(object): assert callback3.call_count == 1 @pytest.mark.it("Recovers from Exception in callback") - def test_callback_raises_exception(self, mocker, mock_mqtt_client, transport): - callback = mocker.MagicMock(side_effect=DummyException) + def test_callback_raises_exception( + self, mocker, mock_mqtt_client, transport, arbitrary_exception + ): + callback = mocker.MagicMock(side_effect=arbitrary_exception) mock_mqtt_client.subscribe.return_value = (fake_rc, fake_mid) transport.subscribe(topic=fake_topic, qos=fake_qos, callback=callback) @@ -799,21 +1107,24 @@ class TestSubscribe(object): assert callback.call_count == 1 @pytest.mark.it("Allows any BaseExceptions raised in callback to propagate") - def test_callback_raises_base_exception(self, mocker, mock_mqtt_client, transport): - callback = mocker.MagicMock(side_effect=UnhandledException) + def test_callback_raises_base_exception( + self, mocker, mock_mqtt_client, transport, arbitrary_base_exception + ): + callback = mocker.MagicMock(side_effect=arbitrary_base_exception) mock_mqtt_client.subscribe.return_value = (fake_rc, fake_mid) transport.subscribe(topic=fake_topic, qos=fake_qos, callback=callback) - with pytest.raises(UnhandledException): + with pytest.raises(arbitrary_base_exception.__class__) as e_info: mock_mqtt_client.on_subscribe( client=mock_mqtt_client, userdata=None, mid=fake_mid, granted_qos=fake_qos ) + assert e_info.value is arbitrary_base_exception @pytest.mark.it("Recovers from Exception in callback when Paho event handler triggered early") def test_callback_rasies_exception_when_paho_on_subscribe_triggered_early( - self, mocker, mock_mqtt_client, transport + self, mocker, mock_mqtt_client, transport, arbitrary_exception ): - callback = mocker.MagicMock(side_effect=DummyException) + callback = mocker.MagicMock(side_effect=arbitrary_exception) def trigger_early_on_subscribe(topic, qos): mock_mqtt_client.on_subscribe( @@ -837,9 +1148,9 @@ class TestSubscribe(object): "Allows any BaseExceptions raised in callback when Paho event handler triggered early to propagate" ) def test_callback_raises_base_exception_when_paho_on_subscribe_triggered_early( - self, mocker, mock_mqtt_client, transport + self, mocker, mock_mqtt_client, transport, arbitrary_base_exception ): - callback = mocker.MagicMock(side_effect=UnhandledException) + callback = mocker.MagicMock(side_effect=arbitrary_base_exception) def trigger_early_on_subscribe(topic, qos): mock_mqtt_client.on_subscribe( @@ -854,16 +1165,37 @@ class TestSubscribe(object): mock_mqtt_client.subscribe.side_effect = trigger_early_on_subscribe # Initiate subscribe - with pytest.raises(UnhandledException): + with pytest.raises(arbitrary_base_exception.__class__) as e_info: transport.subscribe(topic=fake_topic, qos=fake_qos, callback=callback) + assert e_info.value is arbitrary_base_exception + @pytest.mark.it("Raises a ProtocolClientError if Paho subscribe raises an unexpected Exception") + def test_client_raises_unexpected_error( + self, mocker, mock_mqtt_client, transport, arbitrary_exception + ): + mock_mqtt_client.subscribe.side_effect = arbitrary_exception + with pytest.raises(errors.ProtocolClientError) as e_info: + transport.subscribe(topic=fake_topic, qos=fake_qos, callback=None) + assert e_info.value.__cause__ is arbitrary_exception + + @pytest.mark.it("Allows any BaseExceptions raised in Paho subscribe to propagate") + def test_client_raises_base_exception( + self, mock_mqtt_client, transport, arbitrary_base_exception + ): + mock_mqtt_client.subscribe.side_effect = arbitrary_base_exception + with pytest.raises(arbitrary_base_exception.__class__) as e_info: + transport.subscribe(topic=fake_topic, qos=fake_qos, callback=None) + assert e_info.value is arbitrary_base_exception + + # NOTE: this test tests for all possible return codes, even ones that shouldn't be + # possible on a subscribe operation. + @pytest.mark.it("Raises a custom Exception if Paho subscribe returns a failing rc code") @pytest.mark.parametrize( "error_params", operation_return_codes, ids=["{}->{}".format(x["name"], x["error"].__name__) for x in operation_return_codes], ) - @pytest.mark.it("Raises a custom Exception if subscribe returns a failing rc code") - def test_transport_returns_failing_rc_code( + def test_client_returns_failing_rc_code( self, mocker, mock_mqtt_client, transport, error_params ): mock_mqtt_client.subscribe.return_value = (error_params["rc"], 0) @@ -1017,8 +1349,10 @@ class TestUnsubscribe(object): assert callback3.call_count == 1 @pytest.mark.it("Recovers from Exception in callback") - def test_callback_raises_exception(self, mocker, mock_mqtt_client, transport): - callback = mocker.MagicMock(side_effect=DummyException) + def test_callback_raises_exception( + self, mocker, mock_mqtt_client, transport, arbitrary_exception + ): + callback = mocker.MagicMock(side_effect=arbitrary_exception) mock_mqtt_client.unsubscribe.return_value = (fake_rc, fake_mid) transport.unsubscribe(topic=fake_topic, callback=callback) @@ -1028,19 +1362,22 @@ class TestUnsubscribe(object): assert callback.call_count == 1 @pytest.mark.it("Allows any BaseExceptions raised in callback to propagate") - def test_callback_raises_base_exception(self, mocker, mock_mqtt_client, transport): - callback = mocker.MagicMock(side_effect=UnhandledException) + def test_callback_raises_base_exception( + self, mocker, mock_mqtt_client, transport, arbitrary_base_exception + ): + callback = mocker.MagicMock(side_effect=arbitrary_base_exception) mock_mqtt_client.unsubscribe.return_value = (fake_rc, fake_mid) transport.unsubscribe(topic=fake_topic, callback=callback) - with pytest.raises(UnhandledException): + with pytest.raises(arbitrary_base_exception.__class__) as e_info: mock_mqtt_client.on_unsubscribe(client=mock_mqtt_client, userdata=None, mid=fake_mid) + assert e_info.value is arbitrary_base_exception @pytest.mark.it("Recovers from Exception in callback when Paho event handler triggered early") def test_callback_rasies_exception_when_paho_on_unsubscribe_triggered_early( - self, mocker, mock_mqtt_client, transport + self, mocker, mock_mqtt_client, transport, arbitrary_exception ): - callback = mocker.MagicMock(side_effect=DummyException) + callback = mocker.MagicMock(side_effect=arbitrary_exception) def trigger_early_on_unsubscribe(topic): mock_mqtt_client.on_unsubscribe(client=mock_mqtt_client, userdata=None, mid=fake_mid) @@ -1062,9 +1399,9 @@ class TestUnsubscribe(object): "Allows any BaseExceptions raised in callback when Paho event handler triggered early to propagate" ) def test_callback_rasies_base_exception_when_paho_on_unsubscribe_triggered_early( - self, mocker, mock_mqtt_client, transport + self, mocker, mock_mqtt_client, transport, arbitrary_base_exception ): - callback = mocker.MagicMock(side_effect=UnhandledException) + callback = mocker.MagicMock(side_effect=arbitrary_base_exception) def trigger_early_on_unsubscribe(topic): mock_mqtt_client.on_unsubscribe(client=mock_mqtt_client, userdata=None, mid=fake_mid) @@ -1077,16 +1414,39 @@ class TestUnsubscribe(object): mock_mqtt_client.unsubscribe.side_effect = trigger_early_on_unsubscribe # Initiate unsubscribe - with pytest.raises(UnhandledException): + with pytest.raises(arbitrary_base_exception.__class__) as e_info: transport.unsubscribe(topic=fake_topic, callback=callback) + assert e_info.value is arbitrary_base_exception + @pytest.mark.it( + "Raises a ProtocolClientError if Paho unsubscribe raises an unexpected Exception" + ) + def test_client_raises_unexpected_error( + self, mocker, mock_mqtt_client, transport, arbitrary_exception + ): + mock_mqtt_client.unsubscribe.side_effect = arbitrary_exception + with pytest.raises(errors.ProtocolClientError) as e_info: + transport.unsubscribe(topic=fake_topic, callback=None) + assert e_info.value.__cause__ is arbitrary_exception + + @pytest.mark.it("Allows any BaseExceptions raised in Paho unsubscribe to propagate") + def test_client_raises_base_exception( + self, mock_mqtt_client, transport, arbitrary_base_exception + ): + mock_mqtt_client.unsubscribe.side_effect = arbitrary_base_exception + with pytest.raises(arbitrary_base_exception.__class__) as e_info: + transport.unsubscribe(topic=fake_topic, callback=None) + assert e_info.value is arbitrary_base_exception + + # NOTE: this test tests for all possible return codes, even ones that shouldn't be + # possible on an unsubscribe operation. + @pytest.mark.it("Raises a custom Exception if Paho unsubscribe returns a failing rc code") @pytest.mark.parametrize( "error_params", operation_return_codes, ids=["{}->{}".format(x["name"], x["error"].__name__) for x in operation_return_codes], ) - @pytest.mark.it("Raises a custom Exception if unsubscribe returns a failing rc code") - def test_transport_returns_failing_rc_code( + def test_client_returns_failing_rc_code( self, mocker, mock_mqtt_client, transport, error_params ): mock_mqtt_client.unsubscribe.return_value = (error_params["rc"], 0) @@ -1142,7 +1502,7 @@ class TestPublish(object): with pytest.raises(ValueError): transport.publish(topic=topic, payload=fake_payload, qos=fake_qos) - @pytest.mark.it("Raises ValueError on invalid payload") + @pytest.mark.it("Raises ValueError on invalid payload value") @pytest.mark.parametrize("payload", [str(b"0" * 268435456)], ids=["Payload > 268435455 bytes"]) def test_raises_value_error_invalid_payload(self, payload): # Manually instantiate protocol wrapper, do NOT mock paho client (paho generates this error) @@ -1152,6 +1512,23 @@ class TestPublish(object): with pytest.raises(ValueError): transport.publish(topic=fake_topic, payload=payload, qos=fake_qos) + @pytest.mark.it("Raises TypeError on invalid payload type") + @pytest.mark.parametrize( + "payload", + [ + pytest.param({"a": "b"}, id="Dictionary"), + pytest.param([1, 2, 3], id="List"), + pytest.param(object(), id="Object"), + ], + ) + def test_raises_type_error_invalid_payload_type(self, payload): + # Manually instantiate protocol wrapper, do NOT mock paho client (paho generates this error) + transport = MQTTTransport( + client_id=fake_device_id, hostname=fake_hostname, username=fake_username + ) + with pytest.raises(TypeError): + transport.publish(topic=fake_topic, payload=payload, qos=fake_qos) + @pytest.mark.it("Triggers callback upon publish completion") def test_triggers_callback_upon_paho_on_publish_event( self, mocker, mock_mqtt_client, transport, message_info @@ -1283,8 +1660,10 @@ class TestPublish(object): assert callback3.call_count == 1 @pytest.mark.it("Recovers from Exception in callback") - def test_callback_raises_exception(self, mocker, mock_mqtt_client, transport, message_info): - callback = mocker.MagicMock(side_effect=DummyException) + def test_callback_raises_exception( + self, mocker, mock_mqtt_client, transport, message_info, arbitrary_exception + ): + callback = mocker.MagicMock(side_effect=arbitrary_exception) mock_mqtt_client.publish.return_value = message_info transport.publish(topic=fake_topic, payload=fake_payload, callback=callback) @@ -1295,22 +1674,23 @@ class TestPublish(object): @pytest.mark.it("Allows any BaseExceptions raised in callback to propagate") def test_callback_raises_base_exception( - self, mocker, mock_mqtt_client, transport, message_info + self, mocker, mock_mqtt_client, transport, message_info, arbitrary_base_exception ): - callback = mocker.MagicMock(side_effect=UnhandledException) + callback = mocker.MagicMock(side_effect=arbitrary_base_exception) mock_mqtt_client.publish.return_value = message_info transport.publish(topic=fake_topic, payload=fake_payload, callback=callback) - with pytest.raises(UnhandledException): + with pytest.raises(arbitrary_base_exception.__class__) as e_info: mock_mqtt_client.on_publish( client=mock_mqtt_client, userdata=None, mid=message_info.mid ) + assert e_info.value is arbitrary_base_exception @pytest.mark.it("Recovers from Exception in callback when Paho event handler triggered early") def test_callback_rasies_exception_when_paho_on_publish_triggered_early( - self, mocker, mock_mqtt_client, transport, message_info + self, mocker, mock_mqtt_client, transport, message_info, arbitrary_exception ): - callback = mocker.MagicMock(side_effect=DummyException) + callback = mocker.MagicMock(side_effect=arbitrary_exception) def trigger_early_on_publish(topic, payload, qos): mock_mqtt_client.on_publish( @@ -1334,9 +1714,9 @@ class TestPublish(object): "Allows any BaseExceptions raised in callback when Paho event handler triggered early to propagate" ) def test_callback_rasies_base_exception_when_paho_on_publish_triggered_early( - self, mocker, mock_mqtt_client, transport, message_info + self, mocker, mock_mqtt_client, transport, message_info, arbitrary_base_exception ): - callback = mocker.MagicMock(side_effect=UnhandledException) + callback = mocker.MagicMock(side_effect=arbitrary_base_exception) def trigger_early_on_publish(topic, payload, qos): mock_mqtt_client.on_publish( @@ -1351,16 +1731,37 @@ class TestPublish(object): mock_mqtt_client.publish.side_effect = trigger_early_on_publish # Initiate publish - with pytest.raises(UnhandledException): + with pytest.raises(arbitrary_base_exception.__class__) as e_info: transport.publish(topic=fake_topic, payload=fake_payload, callback=callback) + assert e_info.value is arbitrary_base_exception + @pytest.mark.it("Raises a ProtocolClientError if Paho publish raises an unexpected Exception") + def test_client_raises_unexpected_error( + self, mocker, mock_mqtt_client, transport, arbitrary_exception + ): + mock_mqtt_client.publish.side_effect = arbitrary_exception + with pytest.raises(errors.ProtocolClientError) as e_info: + transport.publish(topic=fake_topic, payload=fake_payload, callback=None) + assert e_info.value.__cause__ is arbitrary_exception + + @pytest.mark.it("Allows any BaseExceptions raised in Paho publish to propagate") + def test_client_raises_base_exception( + self, mock_mqtt_client, transport, arbitrary_base_exception + ): + mock_mqtt_client.publish.side_effect = arbitrary_base_exception + with pytest.raises(arbitrary_base_exception.__class__) as e_info: + transport.publish(topic=fake_topic, payload=fake_payload, callback=None) + assert e_info.value is arbitrary_base_exception + + # NOTE: this test tests for all possible return codes, even ones that shouldn't be + # possible on a publish operation. + @pytest.mark.it("Raises a custom Exception if Paho publish returns a failing rc code") @pytest.mark.parametrize( "error_params", operation_return_codes, ids=["{}->{}".format(x["name"], x["error"].__name__) for x in operation_return_codes], ) - @pytest.mark.it("Raises a custom Exception if publish returns a failing rc code") - def test_transport_returns_failing_rc_code( + def test_client_returns_failing_rc_code( self, mocker, mock_mqtt_client, transport, error_params ): mock_mqtt_client.publish.return_value = (error_params["rc"], 0) @@ -1368,7 +1769,7 @@ class TestPublish(object): transport.publish(topic=fake_topic, payload=fake_payload, callback=None) -@pytest.mark.describe("MQTTTransport - EVENT: Message Received") +@pytest.mark.describe("MQTTTransport - OCCURANCE: Message Received") class TestMessageReceived(object): @pytest.fixture() def message(self): @@ -1405,9 +1806,9 @@ class TestMessageReceived(object): @pytest.mark.it("Recovers from Exception in on_mqtt_message_received_handler event handler") def test_event_handler_callback_raises_exception( - self, mocker, mock_mqtt_client, transport, message + self, mocker, mock_mqtt_client, transport, message, arbitrary_exception ): - event_cb = mocker.MagicMock(side_effect=DummyException) + event_cb = mocker.MagicMock(side_effect=arbitrary_exception) transport.on_mqtt_message_received_handler = event_cb mock_mqtt_client.on_message(client=mock_mqtt_client, userdata=None, mqtt_message=message) @@ -1419,15 +1820,16 @@ class TestMessageReceived(object): "Allows any BaseExceptions raised in on_mqtt_message_received_handler event handler to propagate" ) def test_event_handler_callback_raises_base_exception( - self, mocker, mock_mqtt_client, transport, message + self, mocker, mock_mqtt_client, transport, message, arbitrary_base_exception ): - event_cb = mocker.MagicMock(side_effect=UnhandledException) + event_cb = mocker.MagicMock(side_effect=arbitrary_base_exception) transport.on_mqtt_message_received_handler = event_cb - with pytest.raises(UnhandledException): + with pytest.raises(arbitrary_base_exception.__class__) as e_info: mock_mqtt_client.on_message( client=mock_mqtt_client, userdata=None, mqtt_message=message ) + assert e_info.value is arbitrary_base_exception @pytest.mark.describe("MQTTTransport - Misc.") @@ -1547,10 +1949,10 @@ class TestOperationManagerEstablishOperation(object): assert cb_mock.call_count == 1 @pytest.mark.it("Recovers from Exception thrown in callback") - def test_callback_raises_exception(self, mocker): + def test_callback_raises_exception(self, mocker, arbitrary_exception): manager = OperationManager() mid = 1 - cb_mock = mocker.MagicMock(side_effect=DummyException) + cb_mock = mocker.MagicMock(side_effect=arbitrary_exception) # Cause early completion of an unknown operation manager.complete_operation(mid) @@ -1562,17 +1964,18 @@ class TestOperationManagerEstablishOperation(object): assert cb_mock.call_count == 1 @pytest.mark.it("Allows any BaseExceptions raised in callback to propagate") - def test_callback_raises_base_exception(self, mocker): + def test_callback_raises_base_exception(self, mocker, arbitrary_base_exception): manager = OperationManager() mid = 1 - cb_mock = mocker.MagicMock(side_effect=UnhandledException) + cb_mock = mocker.MagicMock(side_effect=arbitrary_base_exception) # Cause early completion of an unknown operation manager.complete_operation(mid) # Establish operation that was already completed - with pytest.raises(UnhandledException): + with pytest.raises(arbitrary_base_exception.__class__) as e_info: manager.establish_operation(mid, cb_mock) + assert e_info.value is arbitrary_base_exception @pytest.mark.it("Does not trigger the callback until after thread lock has been released") def test_callback_called_after_lock_release(self, mocker): @@ -1640,10 +2043,10 @@ class TestOperationManagerCompleteOperation(object): assert cb_mock.call_count == 1 @pytest.mark.it("Recovers from Exception thrown in callback") - def test_callback_raises_exception(self, mocker): + def test_callback_raises_exception(self, mocker, arbitrary_exception): manager = OperationManager() mid = 1 - cb_mock = mocker.MagicMock(side_effect=DummyException) + cb_mock = mocker.MagicMock(side_effect=arbitrary_exception) manager.establish_operation(mid, cb_mock) assert cb_mock.call_count == 0 @@ -1653,16 +2056,17 @@ class TestOperationManagerCompleteOperation(object): assert cb_mock.call_count == 1 @pytest.mark.it("Allows any BaseExceptions raised in callback to propagate") - def test_callback_raises_base_exception(self, mocker): + def test_callback_raises_base_exception(self, mocker, arbitrary_base_exception): manager = OperationManager() mid = 1 - cb_mock = mocker.MagicMock(side_effect=UnhandledException) + cb_mock = mocker.MagicMock(side_effect=arbitrary_base_exception) manager.establish_operation(mid, cb_mock) assert cb_mock.call_count == 0 - with pytest.raises(UnhandledException): + with pytest.raises(arbitrary_base_exception.__class__) as e_info: manager.complete_operation(mid) + assert e_info.value is arbitrary_base_exception @pytest.mark.it( "Begins tracking an unknown completion if MID does not correspond to a pending operation" diff --git a/azure-iot-device/tests/conftest.py b/azure-iot-device/tests/conftest.py new file mode 100644 index 000000000..e991b8bd5 --- /dev/null +++ b/azure-iot-device/tests/conftest.py @@ -0,0 +1,43 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import pytest + +""" +NOTE: ALL (yes, ALL) tests need some kind of non-specific, arbitrary exception should use +one of the following fixtures. This is to ensure the tests operate correctly - many tests used to +raise Exception or BaseException directly to test arbitrary exceptions, but the result was +that exception handling was hiding other errors (also caught by an "except: Exception" block). + +The solution is to use a subclass of Exception or BaseException that is not defined anywhere else, +thus guaranteeing that it will be unexpected and unhandled except by broad all-encompassing +handling. Furthermore, because the exception in question is derived from either Exception or +BaseException, but is not itself an instance of either, tests checking that the exception in +question is raised will not spuriously pass due to different exceptions being raised. + +For consistency, and to prevent confusion, please do this ONLY by using one of the follwing +fixtures. + +You may (and should!) still use exceptions defined elsewhere for specific, non-arbitrary exceptions +(e.g. testing specific exceptions) +""" + + +@pytest.fixture(scope="function") +def arbitrary_exception(): + class ArbitraryException(Exception): + pass + + e = ArbitraryException() + return e + + +@pytest.fixture(scope="function") +def arbitrary_base_exception(): + class ArbitraryBaseException(BaseException): + pass + + return ArbitraryBaseException() diff --git a/azure-iot-device/tests/iothub/aio/test_async_clients.py b/azure-iot-device/tests/iothub/aio/test_async_clients.py index d54203061..1e5b44ea1 100644 --- a/azure-iot-device/tests/iothub/aio/test_async_clients.py +++ b/azure-iot-device/tests/iothub/aio/test_async_clients.py @@ -11,13 +11,16 @@ import threading import time import os import io +from azure.iot.device import exceptions as client_exceptions from azure.iot.device.iothub.aio import IoTHubDeviceClient, IoTHubModuleClient from azure.iot.device.iothub.pipeline import IoTHubPipeline, constant +from azure.iot.device.iothub.pipeline import exceptions as pipeline_exceptions from azure.iot.device.iothub.models import Message, MethodRequest from azure.iot.device.iothub.aio.async_inbox import AsyncClientInbox from azure.iot.device.common import async_adapter from azure.iot.device.iothub.auth import IoTEdgeError -from azure.iot.device.common.models.x509 import X509 +import sys +from azure.iot.device import constant as device_constant pytestmark = pytest.mark.asyncio logging.basicConfig(level=logging.DEBUG) @@ -29,40 +32,58 @@ async def create_completed_future(result=None): return f -# automatically mock the pipeline for all tests in this file. +# automatically mock the mqtt pipeline for all tests in this file. @pytest.fixture(autouse=True) -def mock_pipeline_init(mocker): +def mock_mqtt_pipeline_init(mocker): return mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipeline") +# automatically mock the http pipeline for all tests in this file. +@pytest.fixture(autouse=True) +def mock_http_pipeline_init(mocker): + return mocker.patch("azure.iot.device.iothub.pipeline.HTTPPipeline") + + class SharedClientInstantiationTests(object): @pytest.mark.it( "Stores the IoTHubPipeline from the 'iothub_pipeline' parameter in the '_iothub_pipeline' attribute" ) - async def test_iothub_pipeline_attribute(self, client_class, iothub_pipeline): - client = client_class(iothub_pipeline) + async def test_iothub_pipeline_attribute(self, client_class, iothub_pipeline, http_pipeline): + client = client_class(iothub_pipeline, http_pipeline) assert client._iothub_pipeline is iothub_pipeline + @pytest.mark.it( + "Stores the HTTPPipeline from the 'http_pipeline' parameter in the '_http_pipeline' attribute" + ) + async def test_sets_http_pipeline_attribute(self, client_class, iothub_pipeline, http_pipeline): + client = client_class(iothub_pipeline, http_pipeline) + + assert client._http_pipeline is http_pipeline + @pytest.mark.it("Sets on_connected handler in the IoTHubPipeline") - async def test_sets_on_connected_handler_in_pipeline(self, client_class, iothub_pipeline): - client = client_class(iothub_pipeline) + async def test_sets_on_connected_handler_in_pipeline( + self, client_class, iothub_pipeline, http_pipeline + ): + client = client_class(iothub_pipeline, http_pipeline) assert client._iothub_pipeline.on_connected is not None assert client._iothub_pipeline.on_connected == client._on_connected @pytest.mark.it("Sets on_disconnected handler in the IoTHubPipeline") - async def test_sets_on_disconnected_handler_in_pipeline(self, client_class, iothub_pipeline): - client = client_class(iothub_pipeline) + async def test_sets_on_disconnected_handler_in_pipeline( + self, client_class, iothub_pipeline, http_pipeline + ): + client = client_class(iothub_pipeline, http_pipeline) assert client._iothub_pipeline.on_disconnected is not None assert client._iothub_pipeline.on_disconnected == client._on_disconnected @pytest.mark.it("Sets on_method_request_received handler in the IoTHubPipeline") async def test_sets_on_method_request_received_handler_in_pipleline( - self, client_class, iothub_pipeline + self, client_class, iothub_pipeline, http_pipeline ): - client = client_class(iothub_pipeline) + client = client_class(iothub_pipeline, http_pipeline) assert client._iothub_pipeline.on_method_request_received is not None assert ( @@ -71,90 +92,191 @@ class SharedClientInstantiationTests(object): ) -class SharedClientCreateFromConnectionStringTests(object): +class SharedClientCreateMethodUserOptionTests(object): + # In these tests we patch the entire 'auth' library instead of specific auth providers in order + # to make them more generic, and applicable across all creation methods. + + @pytest.fixture + def option_test_required_patching(self, mocker): + """Override this fixture in a subclass if unique patching is required""" + pass + @pytest.mark.it( - "Uses the connection string and CA certificate combination to create a SymmetricKeyAuthenticationProvider" + "Sets the 'product_info' user option parameter on the PipelineConfig, if provided" ) - @pytest.mark.parametrize( - "ca_cert", - [ - pytest.param(None, id="No CA certificate"), - pytest.param("some-certificate", id="With CA certificate"), - ], + async def test_product_info_option( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + product_info = "MyProductInfo" + client_create_method(*create_method_args, product_info=product_info) + + # Get configuration object, and ensure it was used for both protocol pipelines + assert mock_mqtt_pipeline_init.call_count == 1 + config = mock_mqtt_pipeline_init.call_args[0][1] + assert config == mock_http_pipeline_init.call_args[0][1] + + assert config.product_info == product_info + + @pytest.mark.it( + "Sets the 'websockets' user option parameter on the PipelineConfig, if provided" ) - async def test_auth_provider_creation(self, mocker, client_class, connection_string, ca_cert): + async def test_websockets_option( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + client_create_method(*create_method_args, websockets=True) + + # Get configuration object, and ensure it was used for both protocol pipelines + assert mock_mqtt_pipeline_init.call_count == 1 + config = mock_mqtt_pipeline_init.call_args[0][1] + assert config == mock_http_pipeline_init.call_args[0][1] + + assert config.websockets + + @pytest.mark.it("Sets the 'cipher' user option parameter on the PipelineConfig, if provided") + async def test_cipher_option( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + cipher = "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256" + client_create_method(*create_method_args, cipher=cipher) + + # Get configuration object, and ensure it was used for both protocol pipelines + assert mock_mqtt_pipeline_init.call_count == 1 + config = mock_mqtt_pipeline_init.call_args[0][1] + assert config == mock_http_pipeline_init.call_args[0][1] + + assert config.cipher == cipher + + @pytest.mark.it( + "Sets the 'server_verification_cert' user option parameter on the AuthenticationProvider, if provided" + ) + async def test_server_verification_cert_option( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + server_verification_cert = "fake_server_verification_cert" + client_create_method(*create_method_args, server_verification_cert=server_verification_cert) + + # Get auth provider object, and ensure it was used for both protocol pipelines + auth = mock_mqtt_pipeline_init.call_args[0][0] + assert auth == mock_http_pipeline_init.call_args[0][0] + + assert auth.server_verification_cert == server_verification_cert + + @pytest.mark.it("Raises a TypeError if an invalid user option parameter is provided") + async def test_invalid_option( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + with pytest.raises(TypeError): + client_create_method(*create_method_args, invalid_option="some_value") + + @pytest.mark.it("Sets default user options if none are provided") + async def test_default_options( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + client_create_method(*create_method_args) + + # Get configuration object, and ensure it was used for both protocol pipelines + assert mock_mqtt_pipeline_init.call_count == 1 + config = mock_mqtt_pipeline_init.call_args[0][1] + assert config == mock_http_pipeline_init.call_args[0][1] + + # Get auth provider object, and ensure it was used for both protocol pipelines + auth = mock_mqtt_pipeline_init.call_args[0][0] + assert auth == mock_http_pipeline_init.call_args[0][0] + + assert config.product_info == "" + assert not config.websockets + assert not config.cipher + assert auth.server_verification_cert is None + + +class SharedClientCreateFromConnectionStringTests(object): + @pytest.mark.it("Uses the connection string to create a SymmetricKeyAuthenticationProvider") + async def test_auth_provider_creation(self, mocker, client_class, connection_string): mock_auth_parse = mocker.patch( "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" ).parse + mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipelineConfig") - args = (connection_string,) - kwargs = {} - if ca_cert: - kwargs["ca_cert"] = ca_cert - client_class.create_from_connection_string(*args, **kwargs) + client_class.create_from_connection_string(connection_string) assert mock_auth_parse.call_count == 1 assert mock_auth_parse.call_args == mocker.call(connection_string) - assert mock_auth_parse.return_value.ca_cert is ca_cert @pytest.mark.it("Uses the SymmetricKeyAuthenticationProvider to create an IoTHubPipeline") @pytest.mark.parametrize( - "ca_cert", + "server_verification_cert", [ - pytest.param(None, id="No CA certificate"), - pytest.param("some-certificate", id="With CA certificate"), + pytest.param(None, id="No Server Verification Certificate"), + pytest.param("some-certificate", id="With Server Verification Certificate"), ], ) async def test_pipeline_creation( - self, mocker, client_class, connection_string, ca_cert, mock_pipeline_init + self, + mocker, + client_class, + connection_string, + server_verification_cert, + mock_mqtt_pipeline_init, ): mock_auth = mocker.patch( "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" ).parse.return_value - args = (connection_string,) - kwargs = {} - if ca_cert: - kwargs["ca_cert"] = ca_cert - client_class.create_from_connection_string(*args, **kwargs) + mock_config_init = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipelineConfig") - assert mock_pipeline_init.call_count == 1 - assert mock_pipeline_init.call_args == mocker.call(mock_auth) + client_class.create_from_connection_string(connection_string) + + assert mock_mqtt_pipeline_init.call_count == 1 + assert mock_mqtt_pipeline_init.call_args == mocker.call( + mock_auth, mock_config_init.return_value + ) @pytest.mark.it("Uses the IoTHubPipeline to instantiate the client") - @pytest.mark.parametrize( - "ca_cert", - [ - pytest.param(None, id="No CA certificate"), - pytest.param("some-certificate", id="With CA certificate"), - ], - ) - async def test_client_instantiation(self, mocker, client_class, connection_string, ca_cert): + async def test_client_instantiation(self, mocker, client_class, connection_string): mock_pipeline = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipeline").return_value + mock_pipeline_http = mocker.patch( + "azure.iot.device.iothub.pipeline.HTTPPipeline" + ).return_value spy_init = mocker.spy(client_class, "__init__") - args = (connection_string,) - kwargs = {} - if ca_cert: - kwargs["ca_cert"] = ca_cert - client_class.create_from_connection_string(*args, **kwargs) + + client_class.create_from_connection_string(connection_string) assert spy_init.call_count == 1 - assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline) + assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline, mock_pipeline_http) @pytest.mark.it("Returns the instantiated client") - @pytest.mark.parametrize( - "ca_cert", - [ - pytest.param(None, id="No CA certificate"), - pytest.param("some-certificate", id="With CA certificate"), - ], - ) - async def test_returns_client(self, client_class, connection_string, ca_cert): - args = (connection_string,) - kwargs = {} - if ca_cert: - kwargs["ca_cert"] = ca_cert - client = client_class.create_from_connection_string(*args, **kwargs) + async def test_returns_client(self, client_class, connection_string): + client = client_class.create_from_connection_string(connection_string) assert isinstance(client, client_class) @@ -178,65 +300,6 @@ class SharedClientCreateFromConnectionStringTests(object): client_class.create_from_connection_string(bad_cs) -class SharedClientCreateFromSharedAccessSignature(object): - @pytest.mark.it("Uses the SAS token to create a SharedAccessSignatureAuthenticationProvider") - async def test_auth_provider_creation(self, mocker, client_class, sas_token_string): - mock_auth_parse = mocker.patch( - "azure.iot.device.iothub.auth.SharedAccessSignatureAuthenticationProvider" - ).parse - - client_class.create_from_shared_access_signature(sas_token_string) - - assert mock_auth_parse.call_count == 1 - assert mock_auth_parse.call_args == mocker.call(sas_token_string) - - @pytest.mark.it( - "Uses the SharedAccessSignatureAuthenticationProvider to create an IoTHubPipeline" - ) - async def test_pipeline_creation( - self, mocker, client_class, sas_token_string, mock_pipeline_init - ): - mock_auth = mocker.patch( - "azure.iot.device.iothub.auth.SharedAccessSignatureAuthenticationProvider" - ).parse.return_value - - client_class.create_from_shared_access_signature(sas_token_string) - - assert mock_pipeline_init.call_count == 1 - assert mock_pipeline_init.call_args == mocker.call(mock_auth) - - @pytest.mark.it("Uses the IoTHubPipeline to instantiate the client") - async def test_client_instantiation(self, mocker, client_class, sas_token_string): - mock_pipeline = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipeline").return_value - spy_init = mocker.spy(client_class, "__init__") - - client_class.create_from_shared_access_signature(sas_token_string) - - assert spy_init.call_count == 1 - assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline) - - @pytest.mark.it("Returns the instantiated client") - async def test_returns_client(self, mocker, client_class, sas_token_string): - client = client_class.create_from_shared_access_signature(sas_token_string) - assert isinstance(client, client_class) - - # TODO: If auth package was refactored to use SasToken class, tests from that - # class would increase the coverage here. - @pytest.mark.it("Raises ValueError when given an invalid SAS token") - @pytest.mark.parametrize( - "bad_sas", - [ - pytest.param(object(), id="Non-string input"), - pytest.param( - "SharedAccessSignature sr=Invalid&sig=Invalid&se=Invalid", id="Malformed SAS token" - ), - ], - ) - async def test_raises_value_error_on_bad_sas_token(self, client_class, bad_sas): - with pytest.raises(ValueError): - client_class.create_from_shared_access_signature(bad_sas) - - class SharedClientConnectTests(object): @pytest.mark.it("Begins a 'connect' pipeline operation") async def test_calls_pipeline_connect(self, client, iothub_pipeline): @@ -255,17 +318,57 @@ class SharedClientConnectTests(object): # Assert callback completion is waited upon assert cb_mock.completion.call_count == 1 - @pytest.mark.it("Raises an error if the `connect` pipeline operation calls back with an error") + @pytest.mark.it( + "Raises a client error if the `connect` pipeline operation calls back with a pipeline error" + ) + @pytest.mark.parametrize( + "pipeline_error,client_error", + [ + pytest.param( + pipeline_exceptions.ConnectionDroppedError, + client_exceptions.ConnectionDroppedError, + id="ConnectionDroppedError->ConnectionDroppedError", + ), + pytest.param( + pipeline_exceptions.ConnectionFailedError, + client_exceptions.ConnectionFailedError, + id="ConnectionFailedError->ConnectionFailedError", + ), + pytest.param( + pipeline_exceptions.UnauthorizedError, + client_exceptions.CredentialError, + id="UnauthorizedError->CredentialError", + ), + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param( + pipeline_exceptions.TlsExchangeAuthError, + client_exceptions.ClientError, + id="TlsExchangeAuthError->ClientError", + ), + pytest.param( + pipeline_exceptions.ProtocolProxyError, + client_exceptions.ClientError, + id="ProtocolProxyError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), + ], + ) async def test_raises_error_on_pipeline_op_error( - self, mocker, client, iothub_pipeline, fake_error + self, mocker, client, iothub_pipeline, pipeline_error, client_error ): + my_pipeline_error = pipeline_error() + def fail_connect(callback): - callback(error=fake_error) + callback(error=my_pipeline_error) iothub_pipeline.connect = mocker.MagicMock(side_effect=fail_connect) - with pytest.raises(fake_error.__class__) as e_info: + with pytest.raises(client_error) as e_info: await client.connect() - assert e_info.value is fake_error + assert e_info.value.__cause__ is my_pipeline_error assert iothub_pipeline.connect.call_count == 1 @@ -290,18 +393,31 @@ class SharedClientDisconnectTests(object): assert cb_mock.completion.call_count == 1 @pytest.mark.it( - "Raises an error if the `disconnect` pipeline operation calls back with an error" + "Raises a client error if the `disconnect` pipeline operation calls back with a pipeline error" + ) + @pytest.mark.parametrize( + "pipeline_error,client_error", + [ + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), + ], ) async def test_raises_error_on_pipeline_op_error( - self, mocker, client, iothub_pipeline, fake_error + self, mocker, client, iothub_pipeline, pipeline_error, client_error ): + my_pipeline_error = pipeline_error() + def fail_disconnect(callback): - callback(error=fake_error) + callback(error=my_pipeline_error) iothub_pipeline.disconnect = mocker.MagicMock(side_effect=fail_disconnect) - with pytest.raises(fake_error.__class__) as e_info: + with pytest.raises(client_error) as e_info: await client.disconnect() - assert e_info.value is fake_error + assert e_info.value.__cause__ is my_pipeline_error assert iothub_pipeline.disconnect.call_count == 1 @@ -337,18 +453,46 @@ class SharedClientSendD2CMessageTests(object): assert cb_mock.completion.call_count == 1 @pytest.mark.it( - "Raises an error if the `send_message` pipeline operation calls back with an error" + "Raises a client error if the `send_message` pipeline operation calls back with a pipeline error" + ) + @pytest.mark.parametrize( + "pipeline_error,client_error", + [ + pytest.param( + pipeline_exceptions.ConnectionDroppedError, + client_exceptions.ConnectionDroppedError, + id="ConnectionDroppedError->ConnectionDroppedError", + ), + pytest.param( + pipeline_exceptions.ConnectionFailedError, + client_exceptions.ConnectionFailedError, + id="ConnectionFailedError->ConnectionFailedError", + ), + pytest.param( + pipeline_exceptions.UnauthorizedError, + client_exceptions.CredentialError, + id="UnauthorizedError->CredentialError", + ), + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), + ], ) async def test_raises_error_on_pipeline_op_error( - self, mocker, client, iothub_pipeline, message, fake_error + self, mocker, client, iothub_pipeline, message, client_error, pipeline_error ): + my_pipeline_error = pipeline_error() + def fail_send_message(message, callback): - callback(error=fake_error) + callback(error=my_pipeline_error) iothub_pipeline.send_message = mocker.MagicMock(side_effect=fail_send_message) - with pytest.raises(fake_error.__class__) as e_info: + with pytest.raises(client_error) as e_info: await client.send_message(message) - assert e_info.value is fake_error + assert e_info.value.__cause__ is my_pipeline_error assert iothub_pipeline.send_message.call_count == 1 @pytest.mark.it( @@ -374,6 +518,41 @@ class SharedClientSendD2CMessageTests(object): assert isinstance(sent_message, Message) assert sent_message.data == message_input + @pytest.mark.it("Raises error when message data size is greater than 256 KB") + async def test_raises_error_when_message_data_greater_than_256(self, client, iothub_pipeline): + data_input = "serpensortia" * 256000 + message = Message(data_input) + with pytest.raises(ValueError) as e_info: + await client.send_message(message) + assert "256 KB" in e_info.value.args[0] + assert iothub_pipeline.send_message.call_count == 0 + + @pytest.mark.it("Raises error when message size is greater than 256 KB") + async def test_raises_error_when_message_size_greater_than_256(self, client, iothub_pipeline): + data_input = "serpensortia" + message = Message(data_input) + message.custom_properties["spell"] = data_input * 256000 + with pytest.raises(ValueError) as e_info: + await client.send_message(message) + assert "256 KB" in e_info.value.args[0] + assert iothub_pipeline.send_message.call_count == 0 + + @pytest.mark.it("Does not raises error when message data size is equal to 256 KB") + async def test_raises_error_when_message_data_equal_to_256(self, client, iothub_pipeline): + data_input = "a" * 262095 + message = Message(data_input) + # This check was put as message class may undergo the default content type encoding change + # and the above calculation will change. + if message.get_size() != device_constant.TELEMETRY_MESSAGE_SIZE_LIMIT: + assert False + + await client.send_message(message) + + assert iothub_pipeline.send_message.call_count == 1 + sent_message = iothub_pipeline.send_message.call_args[0][0] + assert isinstance(sent_message, Message) + assert sent_message.data == data_input + class SharedClientReceiveMethodRequestTests(object): @pytest.mark.it("Implicitly enables methods feature if not already enabled") @@ -476,20 +655,48 @@ class SharedClientSendMethodResponseTests(object): assert cb_mock.completion.call_count == 1 @pytest.mark.it( - "Raises an error if the `send_method-response` pipeline operation calls back with an error" + "Raises a client error if the `send_method_response` pipeline operation calls back with a pipeline error" + ) + @pytest.mark.parametrize( + "pipeline_error,client_error", + [ + pytest.param( + pipeline_exceptions.ConnectionDroppedError, + client_exceptions.ConnectionDroppedError, + id="ConnectionDroppedError->ConnectionDroppedError", + ), + pytest.param( + pipeline_exceptions.ConnectionFailedError, + client_exceptions.ConnectionFailedError, + id="ConnectionFailedError->ConnectionFailedError", + ), + pytest.param( + pipeline_exceptions.UnauthorizedError, + client_exceptions.CredentialError, + id="UnauthorizedError->CredentialError", + ), + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), + ], ) async def test_raises_error_on_pipeline_op_error( - self, mocker, client, iothub_pipeline, method_response, fake_error + self, mocker, client, iothub_pipeline, method_response, pipeline_error, client_error ): + my_pipeline_error = pipeline_error() + def fail_send_method_response(response, callback): - callback(error=fake_error) + callback(error=my_pipeline_error) iothub_pipeline.send_method_response = mocker.MagicMock( side_effect=fail_send_method_response ) - with pytest.raises(fake_error.__class__) as e_info: + with pytest.raises(client_error) as e_info: await client.send_method_response(method_response) - assert e_info.value is fake_error + assert e_info.value.__cause__ is my_pipeline_error assert iothub_pipeline.send_method_response.call_count == 1 @@ -543,17 +750,47 @@ class SharedClientGetTwinTests(object): # Assert callback completion is waited upon assert cb_mock.completion.call_count == 1 - @pytest.mark.it("Raises an error if the `get_twin` pipeline operation calls back with an error") + @pytest.mark.it( + "Raises a client error if the `get_twin` pipeline operation calls back with a pipeline error" + ) + @pytest.mark.parametrize( + "pipeline_error,client_error", + [ + pytest.param( + pipeline_exceptions.ConnectionDroppedError, + client_exceptions.ConnectionDroppedError, + id="ConnectionDroppedError->ConnectionDroppedError", + ), + pytest.param( + pipeline_exceptions.ConnectionFailedError, + client_exceptions.ConnectionFailedError, + id="ConnectionFailedError->ConnectionFailedError", + ), + pytest.param( + pipeline_exceptions.UnauthorizedError, + client_exceptions.CredentialError, + id="UnauthorizedError->CredentialError", + ), + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), + ], + ) async def test_raises_error_on_pipeline_op_error( - self, mocker, client, iothub_pipeline, fake_error + self, mocker, client, iothub_pipeline, pipeline_error, client_error ): + my_pipeline_error = pipeline_error() + def fail_get_twin(callback): - callback(error=fake_error) + callback(error=my_pipeline_error) iothub_pipeline.get_twin = mocker.MagicMock(side_effect=fail_get_twin) - with pytest.raises(fake_error.__class__) as e_info: + with pytest.raises(client_error) as e_info: await client.get_twin() - assert e_info.value is fake_error + assert e_info.value.__cause__ is my_pipeline_error assert iothub_pipeline.get_twin.call_count == 1 @pytest.mark.it("Returns the twin that the pipeline returned") @@ -626,20 +863,48 @@ class SharedClientPatchTwinReportedPropertiesTests(object): assert cb_mock.completion.call_count == 1 @pytest.mark.it( - "Raises an error if the `patch_twin_reported_properties` pipeline operation calls back with an error" + "Raises a client error if the `patch_twin_reported_properties` pipeline operation calls back with a pipeline error" + ) + @pytest.mark.parametrize( + "pipeline_error,client_error", + [ + pytest.param( + pipeline_exceptions.ConnectionDroppedError, + client_exceptions.ConnectionDroppedError, + id="ConnectionDroppedError->ConnectionDroppedError", + ), + pytest.param( + pipeline_exceptions.ConnectionFailedError, + client_exceptions.ConnectionFailedError, + id="ConnectionFailedError->ConnectionFailedError", + ), + pytest.param( + pipeline_exceptions.UnauthorizedError, + client_exceptions.CredentialError, + id="UnauthorizedError->CredentialError", + ), + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), + ], ) async def test_raises_error_on_pipeline_op_error( - self, mocker, client, iothub_pipeline, twin_patch_reported, fake_error + self, mocker, client, iothub_pipeline, twin_patch_reported, pipeline_error, client_error ): + my_pipeline_error = pipeline_error() + def fail_patch_twin_reported_properties(patch, callback): - callback(error=fake_error) + callback(error=my_pipeline_error) iothub_pipeline.patch_twin_reported_properties = mocker.MagicMock( side_effect=fail_patch_twin_reported_properties ) - with pytest.raises(fake_error.__class__) as e_info: + with pytest.raises(client_error) as e_info: await client.patch_twin_reported_properties(twin_patch_reported) - assert e_info.value is fake_error + assert e_info.value.__cause__ is my_pipeline_error assert iothub_pipeline.patch_twin_reported_properties.call_count == 1 @@ -684,6 +949,20 @@ class SharedClientReceiveTwinDesiredPropertiesPatchTests(object): assert received_patch is twin_patch_desired +class SharedClientPROPERTYConnectedTests(object): + @pytest.mark.it("Cannot be changed") + async def test_read_only(self, client): + with pytest.raises(AttributeError): + client.connected = not client.connected + + @pytest.mark.it("Reflects the value of the root stage property of the same name") + async def test_reflects_pipeline_property(self, client, iothub_pipeline): + iothub_pipeline.connected = True + assert client.connected + iothub_pipeline.connected = False + assert not client.connected + + ################ # DEVICE TESTS # ################ @@ -693,11 +972,11 @@ class IoTHubDeviceClientTestsConfig(object): return IoTHubDeviceClient @pytest.fixture - def client(self, iothub_pipeline): + def client(self, iothub_pipeline, http_pipeline): """This client automatically resolves callbacks sent to the pipeline. It should be used for the majority of tests. """ - return IoTHubDeviceClient(iothub_pipeline) + return IoTHubDeviceClient(iothub_pipeline, http_pipeline) @pytest.fixture def connection_string(self, device_connection_string): @@ -717,9 +996,9 @@ class TestIoTHubDeviceClientInstantiation( ): @pytest.mark.it("Sets on_c2d_message_received handler in the IoTHubPipeline") async def test_sets_on_c2d_message_received_handler_in_pipeline( - self, client_class, iothub_pipeline + self, client_class, iothub_pipeline, http_pipeline ): - client = client_class(iothub_pipeline) + client = client_class(iothub_pipeline, http_pipeline) assert client._iothub_pipeline.on_c2d_message_received is not None assert ( @@ -727,32 +1006,56 @@ class TestIoTHubDeviceClientInstantiation( == client._inbox_manager.route_c2d_message ) - @pytest.mark.it("Sets the '_edge_pipeline' attribute to None") - async def test_edge_pipeline_is_none(self, client_class, iothub_pipeline): - client = client_class(iothub_pipeline) - - assert client._edge_pipeline is None - @pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .create_from_connection_string()") class TestIoTHubDeviceClientCreateFromConnectionString( - IoTHubDeviceClientTestsConfig, SharedClientCreateFromConnectionStringTests + IoTHubDeviceClientTestsConfig, + SharedClientCreateFromConnectionStringTests, + SharedClientCreateMethodUserOptionTests, ): - pass + @pytest.fixture + def client_create_method(self, client_class): + """Provides the specific create method for use in universal tests""" + return client_class.create_from_connection_string + + @pytest.fixture + def create_method_args(self, connection_string): + """Provides the specific create method args for use in universal tests""" + return [connection_string] -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .create_from_shared_access_signature()") -class TestIoTHubDeviceClientCreateFromSharedAccessSignature( - IoTHubDeviceClientTestsConfig, SharedClientCreateFromSharedAccessSignature +@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .create_from_symmetric_key()") +class TestConfigurationCreateIoTHubDeviceClientFromSymmetricKey( + IoTHubDeviceClientTestsConfig, SharedClientCreateMethodUserOptionTests ): - pass + @pytest.fixture + def client_create_method(self, client_class): + """Provides the specific create method for use in universal tests""" + return client_class.create_from_symmetric_key + + @pytest.fixture + def create_method_args(self, symmetric_key, hostname_fixture, device_id_fixture): + """Provides the specific create method args for use in universal tests""" + return [symmetric_key, hostname_fixture, device_id_fixture] @pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .create_from_x509_certificate()") -class TestIoTHubDeviceClientCreateFromX509Certificate(IoTHubDeviceClientTestsConfig): +class TestIoTHubDeviceClientCreateFromX509Certificate( + IoTHubDeviceClientTestsConfig, SharedClientCreateMethodUserOptionTests +): hostname = "durmstranginstitute.farend" device_id = "MySnitch" + @pytest.fixture + def client_create_method(self, client_class): + """Provides the specific create method for use in universal tests""" + return client_class.create_from_x509_certificate + + @pytest.fixture + def create_method_args(self, x509): + """Provides the specific create method args for use in universal tests""" + return [x509, self.hostname, self.device_id] + @pytest.mark.it("Uses the provided arguments to create a X509AuthenticationProvider") async def test_auth_provider_creation(self, mocker, client_class, x509): mock_auth_init = mocker.patch("azure.iot.device.iothub.auth.X509AuthenticationProvider") @@ -767,21 +1070,28 @@ class TestIoTHubDeviceClientCreateFromX509Certificate(IoTHubDeviceClientTestsCon ) @pytest.mark.it("Uses the X509AuthenticationProvider to create an IoTHubPipeline") - async def test_pipeline_creation(self, mocker, client_class, x509, mock_pipeline_init): + async def test_pipeline_creation(self, mocker, client_class, x509, mock_mqtt_pipeline_init): mock_auth = mocker.patch( "azure.iot.device.iothub.auth.X509AuthenticationProvider" ).return_value + mock_config = mocker.patch( + "azure.iot.device.iothub.pipeline.IoTHubPipelineConfig" + ).return_value + client_class.create_from_x509_certificate( x509=x509, hostname=self.hostname, device_id=self.device_id ) - assert mock_pipeline_init.call_count == 1 - assert mock_pipeline_init.call_args == mocker.call(mock_auth) + assert mock_mqtt_pipeline_init.call_count == 1 + assert mock_mqtt_pipeline_init.call_args == mocker.call(mock_auth, mock_config) @pytest.mark.it("Uses the IoTHubPipeline to instantiate the client") async def test_client_instantiation(self, mocker, client_class, x509): mock_pipeline = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipeline").return_value + mock_pipeline_http = mocker.patch( + "azure.iot.device.iothub.pipeline.HTTPPipeline" + ).return_value spy_init = mocker.spy(client_class, "__init__") client_class.create_from_x509_certificate( @@ -789,7 +1099,7 @@ class TestIoTHubDeviceClientCreateFromX509Certificate(IoTHubDeviceClientTestsCon ) assert spy_init.call_count == 1 - assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline) + assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline, mock_pipeline_http) @pytest.mark.it("Returns the instantiated client") async def test_returns_client(self, mocker, client_class, x509): @@ -810,7 +1120,7 @@ class TestIoTHubDeviceClientDisconnect(IoTHubDeviceClientTestsConfig, SharedClie pass -@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - EVENT: Disconnect") +@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - OCCURANCE: Disconnect") class TestIoTHubDeviceClientDisconnectEvent( IoTHubDeviceClientTestsConfig, SharedClientDisconnectEventTests ): @@ -897,6 +1207,158 @@ class TestIoTHubDeviceClientReceiveTwinDesiredPropertiesPatch( pass +@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) -.get_storage_info_for_blob()") +class TestIoTHubDeviceClientGetStorageInfo(IoTHubDeviceClientTestsConfig): + @pytest.mark.it("Begins a 'get_storage_info_for_blob' HTTPPipeline operation") + async def test_calls_pipeline_get_storage_info_for_blob(self, client, http_pipeline): + fake_blob_name = "__fake_blob_name__" + await client.get_storage_info_for_blob(fake_blob_name) + assert http_pipeline.get_storage_info_for_blob.call_count == 1 + assert http_pipeline.get_storage_info_for_blob.call_args[1]["blob_name"] is fake_blob_name + + @pytest.mark.it( + "Waits for the completion of the 'get_storage_info_for_blob' pipeline operation before returning" + ) + async def test_waits_for_pipeline_op_completion(self, mocker, client, http_pipeline): + fake_blob_name = "__fake_blob_name__" + cb_mock = mocker.patch.object(async_adapter, "AwaitableCallback").return_value + cb_mock.completion.return_value = await create_completed_future(None) + + await client.get_storage_info_for_blob(fake_blob_name) + + # Assert callback is sent to pipeline + assert http_pipeline.get_storage_info_for_blob.call_args[1]["callback"] is cb_mock + # Assert callback completion is waited upon + assert cb_mock.completion.call_count == 1 + + @pytest.mark.it( + "Raises a client error if the `get_storage_info_for_blob` pipeline operation calls back with a pipeline error" + ) + @pytest.mark.parametrize( + "pipeline_error,client_error", + [ + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), + ], + ) + async def test_raises_error_on_pipeline_op_error( + self, mocker, client, http_pipeline, pipeline_error, client_error + ): + fake_blob_name = "__fake_blob_name__" + + my_pipeline_error = pipeline_error() + + def fail_get_storage_info_for_blob(blob_name, callback): + callback(error=my_pipeline_error) + + http_pipeline.get_storage_info_for_blob = mocker.MagicMock( + side_effect=fail_get_storage_info_for_blob + ) + + with pytest.raises(client_error) as e_info: + await client.get_storage_info_for_blob(fake_blob_name) + assert e_info.value.__cause__ is my_pipeline_error + + @pytest.mark.it("Returns a storage_info object upon successful completion") + async def test_returns_storage_info(self, mocker, client, http_pipeline): + fake_blob_name = "__fake_blob_name__" + fake_storage_info = "__fake_storage_info__" + received_storage_info = await client.get_storage_info_for_blob(fake_blob_name) + assert http_pipeline.get_storage_info_for_blob.call_count == 1 + assert http_pipeline.get_storage_info_for_blob.call_args[1]["blob_name"] is fake_blob_name + + assert ( + received_storage_info is fake_storage_info + ) # Note: the return value this is checkign for is defined in client_fixtures.py + + +@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) -.notify_blob_upload_status()") +class TestIoTHubDeviceClientNotifyBlobUploadStatus(IoTHubDeviceClientTestsConfig): + @pytest.mark.it("Begins a 'notify_blob_upload_status' HTTPPipeline operation") + async def test_calls_pipeline_notify_blob_upload_status(self, client, http_pipeline): + correlation_id = "__fake_correlation_id__" + is_success = "__fake_is_success__" + status_code = "__fake_status_code__" + status_description = "__fake_status_description__" + await client.notify_blob_upload_status( + correlation_id, is_success, status_code, status_description + ) + kwargs = http_pipeline.notify_blob_upload_status.call_args[1] + assert http_pipeline.notify_blob_upload_status.call_count == 1 + assert kwargs["correlation_id"] is correlation_id + assert kwargs["is_success"] is is_success + assert kwargs["status_code"] is status_code + assert kwargs["status_description"] is status_description + + @pytest.mark.it( + "Waits for the completion of the 'notify_blob_upload_status' pipeline operation before returning" + ) + async def test_waits_for_pipeline_op_completion(self, mocker, client, http_pipeline): + correlation_id = "__fake_correlation_id__" + is_success = "__fake_is_success__" + status_code = "__fake_status_code__" + status_description = "__fake_status_description__" + cb_mock = mocker.patch.object(async_adapter, "AwaitableCallback").return_value + cb_mock.completion.return_value = await create_completed_future(None) + await client.notify_blob_upload_status( + correlation_id, is_success, status_code, status_description + ) + + # Assert callback is sent to pipeline + assert http_pipeline.notify_blob_upload_status.call_args[1]["callback"] is cb_mock + # Assert callback completion is waited upon + assert cb_mock.completion.call_count == 1 + + @pytest.mark.it( + "Raises a client error if the `notify_blob_upload_status` pipeline operation calls back with a pipeline error" + ) + @pytest.mark.parametrize( + "pipeline_error,client_error", + [ + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), + ], + ) + async def test_raises_error_on_pipeline_op_error( + self, mocker, client, http_pipeline, pipeline_error, client_error + ): + correlation_id = "__fake_correlation_id__" + is_success = "__fake_is_success__" + status_code = "__fake_status_code__" + status_description = "__fake_status_description__" + my_pipeline_error = pipeline_error() + + def fail_notify_blob_upload_status( + correlation_id, is_success, status_code, status_description, callback + ): + callback(error=my_pipeline_error) + + http_pipeline.notify_blob_upload_status = mocker.MagicMock( + side_effect=fail_notify_blob_upload_status + ) + + with pytest.raises(client_error) as e_info: + await client.notify_blob_upload_status( + correlation_id, is_success, status_code, status_description + ) + assert e_info.value.__cause__ is my_pipeline_error + + +@pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - PROPERTY .connected") +class TestIoTHubDeviceClientPROPERTYConnected( + IoTHubDeviceClientTestsConfig, SharedClientPROPERTYConnectedTests +): + pass + + ################ # MODULE TESTS # ################ @@ -906,11 +1368,11 @@ class IoTHubModuleClientTestsConfig(object): return IoTHubModuleClient @pytest.fixture - def client(self, iothub_pipeline): + def client(self, iothub_pipeline, http_pipeline): """This client automatically resolves callbacks sent to the pipeline. It should be used for the majority of tests. """ - return IoTHubModuleClient(iothub_pipeline) + return IoTHubModuleClient(iothub_pipeline, http_pipeline) @pytest.fixture def connection_string(self, module_connection_string): @@ -930,9 +1392,9 @@ class TestIoTHubModuleClientInstantiation( ): @pytest.mark.it("Sets on_input_message_received handler in the IoTHubPipeline") async def test_sets_on_input_message_received_handler_in_pipeline( - self, client_class, iothub_pipeline + self, client_class, iothub_pipeline, http_pipeline ): - client = client_class(iothub_pipeline) + client = client_class(iothub_pipeline, http_pipeline) assert client._iothub_pipeline.on_input_message_received is not None assert ( @@ -940,48 +1402,106 @@ class TestIoTHubModuleClientInstantiation( == client._inbox_manager.route_input_message ) - @pytest.mark.it( - "Stores the EdgePipeline from the optionally-provided 'edge_pipeline' parameter in the '_edge_pipeline' attribute" - ) - async def test_sets_edge_pipeline_attribute(self, client_class, iothub_pipeline, edge_pipeline): - client = client_class(iothub_pipeline, edge_pipeline) - - assert client._edge_pipeline is edge_pipeline - - @pytest.mark.it( - "Sets the '_edge_pipeline' attribute to None, if the 'edge_pipeline' parameter is not provided" - ) - async def test_edge_pipeline_default_none(self, client_class, iothub_pipeline): - client = client_class(iothub_pipeline) - - assert client._edge_pipeline is None - @pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .create_from_connection_string()") class TestIoTHubModuleClientCreateFromConnectionString( - IoTHubModuleClientTestsConfig, SharedClientCreateFromConnectionStringTests + IoTHubModuleClientTestsConfig, + SharedClientCreateFromConnectionStringTests, + SharedClientCreateMethodUserOptionTests, ): - pass + @pytest.fixture + def client_create_method(self, client_class): + """Provides the specific create method for use in universal tests""" + return client_class.create_from_connection_string + + @pytest.fixture + def create_method_args(self, connection_string): + """Provides the specific create method args for use in universal tests""" + return [connection_string] -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .create_from_shared_access_signature()") -class TestIoTHubModuleClientCreateFromSharedAccessSignature( - IoTHubModuleClientTestsConfig, SharedClientCreateFromSharedAccessSignature +class IoTHubModuleClientClientCreateFromEdgeEnvironmentUserOptionTests( + SharedClientCreateMethodUserOptionTests ): - pass + """This class inherites the user option tests shared by all create method APIs, and overrides + tests in order to accomodate unique requirements for the .create_from_edge_enviornment() method. + + Because .create_from_edge_environment() tests are spread accross multiple test units + (i.e. test classes), these overrides are done in this class, which is then inherited by all + .create_from_edge_environment() test units below. + """ + + @pytest.fixture + def client_create_method(self, client_class): + """Provides the specific create method for use in universal tests""" + return client_class.create_from_edge_environment + + @pytest.fixture + def create_method_args(self): + """Provides the specific create method args for use in universal tests""" + return [] + + @pytest.mark.it( + "Raises a TypeError if the 'server_verification_cert' user option parameter is provided" + ) + async def test_server_verification_cert_option( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + """THIS TEST OVERRIDES AN INHERITED TEST""" + + with pytest.raises(TypeError): + client_create_method( + *create_method_args, server_verification_cert="fake_server_verification_cert" + ) + + @pytest.mark.it("Sets default user options if none are provided") + async def test_default_options( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + """THIS TEST OVERRIDES AN INHERITED TEST""" + client_create_method(*create_method_args) + + # Get configuration object, and ensure it was used for both protocol pipelines + assert mock_mqtt_pipeline_init.call_count == 1 + config = mock_mqtt_pipeline_init.call_args[0][1] + assert config == mock_http_pipeline_init.call_args[0][1] + + # Get auth provider object, and ensure it was used for both protocol pipelines + auth = mock_mqtt_pipeline_init.call_args[0][0] + assert auth == mock_http_pipeline_init.call_args[0][0] + + assert config.product_info == "" + assert not config.websockets + assert not config.cipher @pytest.mark.describe( "IoTHubModuleClient (Asynchronous) - .create_from_edge_environment() -- Edge Container Environment" ) class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnv( - IoTHubModuleClientTestsConfig + IoTHubModuleClientTestsConfig, IoTHubModuleClientClientCreateFromEdgeEnvironmentUserOptionTests ): + @pytest.fixture + def option_test_required_patching(self, mocker, edge_container_environment): + """THIS FIXTURE OVERRIDES AN INHERITED FIXTURE""" + mocker.patch("azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider") + mocker.patch.dict(os.environ, edge_container_environment, clear=True) + @pytest.mark.it( "Uses Edge container environment variables to create an IoTEdgeAuthenticationProvider" ) async def test_auth_provider_creation(self, mocker, client_class, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment) + mocker.patch.dict(os.environ, edge_container_environment, clear=True) mock_auth_init = mocker.patch("azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider") client_class.create_from_edge_environment() @@ -1006,7 +1526,7 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnv( # This test verifies that with a hybrid environment, the auth provider will always be # an IoTEdgeAuthenticationProvider, even if local debug variables are present hybrid_environment = {**edge_container_environment, **edge_local_debug_environment} - mocker.patch.dict(os.environ, hybrid_environment) + mocker.patch.dict(os.environ, hybrid_environment, clear=True) mock_edge_auth_init = mocker.patch( "azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider" ) @@ -1029,33 +1549,37 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnv( ) @pytest.mark.it( - "Uses the IoTEdgeAuthenticationProvider to create an IoTHubPipeline and an EdgePipeline" + "Uses the IoTEdgeAuthenticationProvider to create an IoTHubPipeline and an HTTPPipeline" ) async def test_pipeline_creation(self, mocker, client_class, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment) + mocker.patch.dict(os.environ, edge_container_environment, clear=True) mock_auth = mocker.patch( "azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider" ).return_value - mock_iothub_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipeline") - mock_edge_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.EdgePipeline") + mock_config = mocker.patch( + "azure.iot.device.iothub.pipeline.IoTHubPipelineConfig" + ).return_value + + mock_mqtt_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipeline") + mock_http_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.HTTPPipeline") client_class.create_from_edge_environment() - assert mock_iothub_pipeline_init.call_count == 1 - assert mock_iothub_pipeline_init.call_args == mocker.call(mock_auth) - assert mock_edge_pipeline_init.call_count == 1 - assert mock_edge_pipeline_init.call_args == mocker.call(mock_auth) + assert mock_mqtt_pipeline_init.call_count == 1 + assert mock_mqtt_pipeline_init.call_args == mocker.call(mock_auth, mock_config) + assert mock_http_pipeline_init.call_count == 1 + assert mock_http_pipeline_init.call_args == mocker.call(mock_auth, mock_config) - @pytest.mark.it("Uses the IoTHubPipeline and the EdgePipeline to instantiate the client") + @pytest.mark.it("Uses the IoTHubPipeline and the HTTPPipeline to instantiate the client") async def test_client_instantiation(self, mocker, client_class, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment) + mocker.patch.dict(os.environ, edge_container_environment, clear=True) # Always patch the IoTEdgeAuthenticationProvider to prevent I/O operations mocker.patch("azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider") mock_iothub_pipeline = mocker.patch( "azure.iot.device.iothub.pipeline.IoTHubPipeline" ).return_value - mock_edge_pipeline = mocker.patch( - "azure.iot.device.iothub.pipeline.EdgePipeline" + mock_http_pipeline = mocker.patch( + "azure.iot.device.iothub.pipeline.HTTPPipeline" ).return_value spy_init = mocker.spy(client_class, "__init__") @@ -1063,12 +1587,12 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnv( assert spy_init.call_count == 1 assert spy_init.call_args == mocker.call( - mocker.ANY, mock_iothub_pipeline, edge_pipeline=mock_edge_pipeline + mocker.ANY, mock_iothub_pipeline, mock_http_pipeline ) @pytest.mark.it("Returns the instantiated client") async def test_returns_client(self, mocker, client_class, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment) + mocker.patch.dict(os.environ, edge_container_environment, clear=True) # Always patch the IoTEdgeAuthenticationProvider to prevent I/O operations mocker.patch("azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider") @@ -1076,7 +1600,7 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnv( assert isinstance(client, client_class) - @pytest.mark.it("Raises IoTEdgeError if the environment is missing required variables") + @pytest.mark.it("Raises OSError if the environment is missing required variables") @pytest.mark.parametrize( "missing_env_var", [ @@ -1094,37 +1618,47 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnv( ): # Remove a variable from the fixture del edge_container_environment[missing_env_var] - mocker.patch.dict(os.environ, edge_container_environment) + mocker.patch.dict(os.environ, edge_container_environment, clear=True) - with pytest.raises(IoTEdgeError): + with pytest.raises(OSError): client_class.create_from_edge_environment() - @pytest.mark.it("Raises IoTEdgeError if there is an error using the Edge for authentication") + @pytest.mark.it("Raises OSError if there is an error using the Edge for authentication") async def test_bad_edge_auth(self, mocker, client_class, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment) + mocker.patch.dict(os.environ, edge_container_environment, clear=True) mock_auth = mocker.patch("azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider") - mock_auth.side_effect = IoTEdgeError - - with pytest.raises(IoTEdgeError): + error = IoTEdgeError() + mock_auth.side_effect = error + with pytest.raises(OSError) as e_info: client_class.create_from_edge_environment() + assert e_info.value.__cause__ is error @pytest.mark.describe( "IoTHubModuleClient (Asynchronous) - .create_from_edge_environment() -- Edge Local Debug Environment" ) -class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv(IoTHubModuleClientTestsConfig): +class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv( + IoTHubModuleClientTestsConfig, IoTHubModuleClientClientCreateFromEdgeEnvironmentUserOptionTests +): + @pytest.fixture + def option_test_required_patching(self, mocker, edge_local_debug_environment): + """THIS FIXTURE OVERRIDES AN INHERITED FIXTURE""" + mocker.patch("azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider") + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) + mocker.patch.object(io, "open") + @pytest.fixture def mock_open(self, mocker): return mocker.patch.object(io, "open") @pytest.mark.it( - "Extracts the CA certificate from the file indicated by the EdgeModuleCACertificateFile environment variable" + "Extracts the server verification certificate from the file indicated by the EdgeModuleCACertificateFile environment variable" ) - async def test_read_ca_cert( + async def test_read_server_verification_cert( self, mocker, client_class, edge_local_debug_environment, mock_open ): mock_file_handle = mock_open.return_value.__enter__.return_value - mocker.patch.dict(os.environ, edge_local_debug_environment) + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) client_class.create_from_edge_environment() assert mock_open.call_count == 1 assert mock_open.call_args == mocker.call( @@ -1133,13 +1667,13 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv(IoTHubModuleCl assert mock_file_handle.read.call_count == 1 @pytest.mark.it( - "Uses Edge local debug environment variables to create a SymmetricKeyAuthenticationProvider (with CA cert)" + "Uses Edge local debug environment variables to create a SymmetricKeyAuthenticationProvider (with server verification cert)" ) async def test_auth_provider_creation( self, mocker, client_class, edge_local_debug_environment, mock_open ): expected_cert = mock_open.return_value.__enter__.return_value.read.return_value - mocker.patch.dict(os.environ, edge_local_debug_environment) + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) mock_auth_parse = mocker.patch( "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" ).parse @@ -1150,7 +1684,7 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv(IoTHubModuleCl assert mock_auth_parse.call_args == mocker.call( edge_local_debug_environment["EdgeHubConnectionString"] ) - assert mock_auth_parse.return_value.ca_cert == expected_cert + assert mock_auth_parse.return_value.server_verification_cert == expected_cert @pytest.mark.it( "Only uses Edge local debug variables if no Edge container variables are present in the environment" @@ -1166,7 +1700,7 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv(IoTHubModuleCl # This test verifies that with a hybrid environment, the auth provider will always be # an IoTEdgeAuthenticationProvider, even if local debug variables are present hybrid_environment = {**edge_container_environment, **edge_local_debug_environment} - mocker.patch.dict(os.environ, hybrid_environment) + mocker.patch.dict(os.environ, hybrid_environment, clear=True) mock_edge_auth_init = mocker.patch( "azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider" ) @@ -1189,35 +1723,38 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv(IoTHubModuleCl ) @pytest.mark.it( - "Uses the SymmetricKeyAuthenticationProvider to create an IoTHubPipeline and an EdgePipeline" + "Uses the SymmetricKeyAuthenticationProvider to create an IoTHubPipeline and an HTTPPipeline" ) async def test_pipeline_creation( self, mocker, client_class, edge_local_debug_environment, mock_open ): - mocker.patch.dict(os.environ, edge_local_debug_environment) + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) mock_auth = mocker.patch( "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" ).parse.return_value - mock_iothub_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipeline") - mock_edge_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.EdgePipeline") + mock_config = mocker.patch( + "azure.iot.device.iothub.pipeline.IoTHubPipelineConfig" + ).return_value + mock_mqtt_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipeline") + mock_http_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.HTTPPipeline") client_class.create_from_edge_environment() - assert mock_iothub_pipeline_init.call_count == 1 - assert mock_iothub_pipeline_init.call_args == mocker.call(mock_auth) - assert mock_edge_pipeline_init.call_count == 1 - assert mock_iothub_pipeline_init.call_args == mocker.call(mock_auth) + assert mock_mqtt_pipeline_init.call_count == 1 + assert mock_mqtt_pipeline_init.call_args == mocker.call(mock_auth, mock_config) + assert mock_http_pipeline_init.call_count == 1 + assert mock_http_pipeline_init.call_args == mocker.call(mock_auth, mock_config) - @pytest.mark.it("Uses the IoTHubPipeline and the EdgePipeline to instantiate the client") + @pytest.mark.it("Uses the IoTHubPipeline and the HTTPPipeline to instantiate the client") async def test_client_instantiation( self, mocker, client_class, edge_local_debug_environment, mock_open ): - mocker.patch.dict(os.environ, edge_local_debug_environment) + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) mock_iothub_pipeline = mocker.patch( "azure.iot.device.iothub.pipeline.IoTHubPipeline" ).return_value - mock_edge_pipeline = mocker.patch( - "azure.iot.device.iothub.pipeline.EdgePipeline" + mock_http_pipeline = mocker.patch( + "azure.iot.device.iothub.pipeline.HTTPPipeline" ).return_value spy_init = mocker.spy(client_class, "__init__") @@ -1225,20 +1762,20 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv(IoTHubModuleCl assert spy_init.call_count == 1 assert spy_init.call_args == mocker.call( - mocker.ANY, mock_iothub_pipeline, edge_pipeline=mock_edge_pipeline + mocker.ANY, mock_iothub_pipeline, mock_http_pipeline ) @pytest.mark.it("Returns the instantiated client") async def test_returns_client( self, mocker, client_class, edge_local_debug_environment, mock_open ): - mocker.patch.dict(os.environ, edge_local_debug_environment) + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) client = client_class.create_from_edge_environment() assert isinstance(client, client_class) - @pytest.mark.it("Raises IoTEdgeError if the environment is missing required variables") + @pytest.mark.it("Raises OSError if the environment is missing required variables") @pytest.mark.parametrize( "missing_env_var", ["EdgeHubConnectionString", "EdgeModuleCACertificateFile"] ) @@ -1247,9 +1784,9 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv(IoTHubModuleCl ): # Remove a variable from the fixture del edge_local_debug_environment[missing_env_var] - mocker.patch.dict(os.environ, edge_local_debug_environment) + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - with pytest.raises(IoTEdgeError): + with pytest.raises(OSError): client_class.create_from_edge_environment() # TODO: If auth package was refactored to use ConnectionString class, tests from that @@ -1273,7 +1810,7 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv(IoTHubModuleCl self, mocker, client_class, edge_local_debug_environment, bad_cs, mock_open ): edge_local_debug_environment["EdgeHubConnectionString"] = bad_cs - mocker.patch.dict(os.environ, edge_local_debug_environment) + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) with pytest.raises(ValueError): client_class.create_from_edge_environment() @@ -1284,27 +1821,43 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv(IoTHubModuleCl async def test_bad_filepath( self, mocker, client_class, edge_local_debug_environment, mock_open ): - mocker.patch.dict(os.environ, edge_local_debug_environment) - mock_open.side_effect = FileNotFoundError - with pytest.raises(ValueError): + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) + error = FileNotFoundError() + mock_open.side_effect = error + with pytest.raises(ValueError) as e_info: client_class.create_from_edge_environment() + assert e_info.value.__cause__ is error @pytest.mark.it( "Raises ValueError if the file referenced by the filepath in the EdgeModuleCACertificateFile environment variable cannot be opened" ) async def test_bad_file_io(self, mocker, client_class, edge_local_debug_environment, mock_open): - mocker.patch.dict(os.environ, edge_local_debug_environment) - mock_open.side_effect = OSError - with pytest.raises(ValueError): + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) + error = OSError() + mock_open.side_effect = error + with pytest.raises(ValueError) as e_info: client_class.create_from_edge_environment() + assert e_info.value.__cause__ is error @pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .create_from_x509_certificate()") -class TestIoTHubModuleClientCreateFromX509Certificate(IoTHubModuleClientTestsConfig): +class TestIoTHubModuleClientCreateFromX509Certificate( + IoTHubModuleClientTestsConfig, SharedClientCreateMethodUserOptionTests +): hostname = "durmstranginstitute.farend" device_id = "MySnitch" module_id = "Charms" + @pytest.fixture + def client_create_method(self, client_class): + """Provides the specific create method for use in universal tests""" + return client_class.create_from_x509_certificate + + @pytest.fixture + def create_method_args(self, x509): + """Provides the specific create method args for use in universal tests""" + return [x509, self.hostname, self.device_id, self.module_id] + @pytest.mark.it("Uses the provided arguments to create a X509AuthenticationProvider") async def test_auth_provider_creation(self, mocker, client_class, x509): mock_auth_init = mocker.patch("azure.iot.device.iothub.auth.X509AuthenticationProvider") @@ -1319,21 +1872,28 @@ class TestIoTHubModuleClientCreateFromX509Certificate(IoTHubModuleClientTestsCon ) @pytest.mark.it("Uses the X509AuthenticationProvider to create an IoTHubPipeline") - async def test_pipeline_creation(self, mocker, client_class, x509, mock_pipeline_init): + async def test_pipeline_creation(self, mocker, client_class, x509, mock_mqtt_pipeline_init): mock_auth = mocker.patch( "azure.iot.device.iothub.auth.X509AuthenticationProvider" ).return_value + mock_config = mocker.patch( + "azure.iot.device.iothub.pipeline.IoTHubPipelineConfig" + ).return_value + client_class.create_from_x509_certificate( x509=x509, hostname=self.hostname, device_id=self.device_id, module_id=self.module_id ) - assert mock_pipeline_init.call_count == 1 - assert mock_pipeline_init.call_args == mocker.call(mock_auth) + assert mock_mqtt_pipeline_init.call_count == 1 + assert mock_mqtt_pipeline_init.call_args == mocker.call(mock_auth, mock_config) @pytest.mark.it("Uses the IoTHubPipeline to instantiate the client") async def test_client_instantiation(self, mocker, client_class, x509): mock_pipeline = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipeline").return_value + mock_pipeline_http = mocker.patch( + "azure.iot.device.iothub.pipeline.HTTPPipeline" + ).return_value spy_init = mocker.spy(client_class, "__init__") client_class.create_from_x509_certificate( @@ -1341,7 +1901,7 @@ class TestIoTHubModuleClientCreateFromX509Certificate(IoTHubModuleClientTestsCon ) assert spy_init.call_count == 1 - assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline) + assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline, mock_pipeline_http) @pytest.mark.it("Returns the instantiated client") async def test_returns_client(self, mocker, client_class, x509): @@ -1362,7 +1922,7 @@ class TestIoTHubModuleClientDisconnect(IoTHubModuleClientTestsConfig, SharedClie pass -@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - EVENT: Disconnect") +@pytest.mark.describe("IoTHubModuleClient (Asynchronous) - OCCURANCE: Disconnect") class TestIoTHubModuleClientDisconnectEvent( IoTHubModuleClientTestsConfig, SharedClientDisconnectEventTests ): @@ -1402,19 +1962,47 @@ class TestIoTHubModuleClientSendToOutput(IoTHubModuleClientTestsConfig): assert cb_mock.completion.call_count == 1 @pytest.mark.it( - "Raises an error if the `send_output_event` pipeline operation calls back with an error" + "Raises a client error if the `send_output_event` pipeline operation calls back with a pipeline error" + ) + @pytest.mark.parametrize( + "pipeline_error,client_error", + [ + pytest.param( + pipeline_exceptions.ConnectionDroppedError, + client_exceptions.ConnectionDroppedError, + id="ConnectionDroppedError->ConnectionDroppedError", + ), + pytest.param( + pipeline_exceptions.ConnectionFailedError, + client_exceptions.ConnectionFailedError, + id="ConnectionFailedError->ConnectionFailedError", + ), + pytest.param( + pipeline_exceptions.UnauthorizedError, + client_exceptions.CredentialError, + id="UnauthorizedError->CredentialError", + ), + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), + ], ) async def test_raises_error_on_pipeline_op_error( - self, mocker, client, iothub_pipeline, message, fake_error + self, mocker, client, iothub_pipeline, message, pipeline_error, client_error ): + my_pipeline_error = pipeline_error() + def fail_send_output_event(message, callback): - callback(error=fake_error) + callback(error=my_pipeline_error) iothub_pipeline.send_output_event = mocker.MagicMock(side_effect=fail_send_output_event) - with pytest.raises(fake_error.__class__) as e_info: + with pytest.raises(client_error) as e_info: output_name = "some_output" await client.send_message_to_output(message, output_name) - assert e_info.value is fake_error + assert e_info.value.__cause__ is my_pipeline_error assert iothub_pipeline.send_output_event.call_count == 1 @pytest.mark.it( @@ -1441,6 +2029,50 @@ class TestIoTHubModuleClientSendToOutput(IoTHubModuleClientTestsConfig): assert isinstance(sent_message, Message) assert sent_message.data == message_input + @pytest.mark.it("Raises error when message data size is greater than 256 KB") + async def test_raises_error_when_message_to_output_data_greater_than_256( + self, client, iothub_pipeline + ): + output_name = "some_output" + data_input = "serpensortia" * 256000 + message = Message(data_input) + with pytest.raises(ValueError) as e_info: + await client.send_message_to_output(message, output_name) + assert "256 KB" in e_info.value.args[0] + assert iothub_pipeline.send_output_event.call_count == 0 + + @pytest.mark.it("Raises error when message size is greater than 256 KB") + async def test_raises_error_when_message_to_output_size_greater_than_256( + self, client, iothub_pipeline + ): + output_name = "some_output" + data_input = "serpensortia" + message = Message(data_input) + message.custom_properties["spell"] = data_input * 256000 + with pytest.raises(ValueError) as e_info: + await client.send_message_to_output(message, output_name) + assert "256 KB" in e_info.value.args[0] + assert iothub_pipeline.send_output_event.call_count == 0 + + @pytest.mark.it("Does not raises error when message data size is equal to 256 KB") + async def test_raises_error_when_message_to_output_data_equal_to_256( + self, client, iothub_pipeline + ): + output_name = "some_output" + data_input = "a" * 262095 + message = Message(data_input) + # This check was put as message class may undergo the default content type encoding change + # and the above calculation will change. + if message.get_size() != device_constant.TELEMETRY_MESSAGE_SIZE_LIMIT: + assert False + + await client.send_message_to_output(message, output_name) + + assert iothub_pipeline.send_output_event.call_count == 1 + sent_message = iothub_pipeline.send_output_event.call_args[0][0] + assert isinstance(sent_message, Message) + assert sent_message.data == data_input + @pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .receive_message_on_input()") class TestIoTHubModuleClientReceiveInputMessage(IoTHubModuleClientTestsConfig): @@ -1520,3 +2152,85 @@ class TestIoTHubModuleClientReceiveTwinDesiredPropertiesPatch( IoTHubModuleClientTestsConfig, SharedClientReceiveTwinDesiredPropertiesPatchTests ): pass + + +@pytest.mark.describe("IoTHubModuleClient (Synchronous) -.invoke_method()") +class TestIoTHubModuleClientInvokeMethod(IoTHubModuleClientTestsConfig): + @pytest.mark.it("Begins a 'invoke_method' HTTPPipeline operation where the target is a device") + async def test_calls_pipeline_invoke_method_for_device(self, mocker, client, http_pipeline): + method_params = "__fake_method_params__" + device_id = "__fake_device_id__" + await client.invoke_method(method_params, device_id) + assert http_pipeline.invoke_method.call_count == 1 + assert http_pipeline.invoke_method.call_args == mocker.call( + device_id, method_params, callback=mocker.ANY, module_id=None + ) + + @pytest.mark.it("Begins a 'invoke_method' HTTPPipeline operation where the target is a module") + async def test_calls_pipeline_invoke_method_for_module(self, mocker, client, http_pipeline): + method_params = "__fake_method_params__" + device_id = "__fake_device_id__" + module_id = "__fake_module_id__" + await client.invoke_method(method_params, device_id, module_id=module_id) + assert http_pipeline.invoke_method.call_count == 1 + # assert http_pipeline.invoke_method.call_args[0][0] is device_id + # assert http_pipeline.invoke_method.call_args[0][1] is method_params + assert http_pipeline.invoke_method.call_args == mocker.call( + device_id, method_params, callback=mocker.ANY, module_id=module_id + ) + + @pytest.mark.it( + "Waits for the completion of the 'invoke_method' pipeline operation before returning" + ) + async def test_waits_for_pipeline_op_completion(self, mocker, client, http_pipeline): + method_params = "__fake_method_params__" + device_id = "__fake_device_id__" + module_id = "__fake_module_id__" + cb_mock = mocker.patch.object(async_adapter, "AwaitableCallback").return_value + cb_mock.completion.return_value = await create_completed_future(None) + + await client.invoke_method(method_params, device_id, module_id=module_id) + + # Assert callback is sent to pipeline + assert http_pipeline.invoke_method.call_args[1]["callback"] is cb_mock + # Assert callback completion is waited upon + assert cb_mock.completion.call_count == 1 + + @pytest.mark.it( + "Raises a client error if the `invoke_method` pipeline operation calls back with a pipeline error" + ) + @pytest.mark.parametrize( + "pipeline_error,client_error", + [ + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), + ], + ) + async def test_raises_error_on_pipeline_op_error( + self, mocker, client, http_pipeline, pipeline_error, client_error + ): + method_params = "__fake_method_params__" + device_id = "__fake_device_id__" + module_id = "__fake_module_id__" + my_pipeline_error = pipeline_error() + + def fail_invoke_method(method_params, device_id, callback, module_id=None): + return callback(error=my_pipeline_error) + + http_pipeline.invoke_method = mocker.MagicMock(side_effect=fail_invoke_method) + + with pytest.raises(client_error) as e_info: + await client.invoke_method(method_params, device_id, module_id=module_id) + + assert e_info.value.__cause__ is my_pipeline_error + + +@pytest.mark.describe("IoTHubModule (Asynchronous) - PROPERTY .connected") +class TestIoTHubModuleClientPROPERTYConnected( + IoTHubModuleClientTestsConfig, SharedClientPROPERTYConnectedTests +): + pass diff --git a/azure-iot-device/tests/iothub/auth/test_base_renewable_token_authentication_provider.py b/azure-iot-device/tests/iothub/auth/test_base_renewable_token_authentication_provider.py index 2c13c696f..a3e40a453 100644 --- a/azure-iot-device/tests/iothub/auth/test_base_renewable_token_authentication_provider.py +++ b/azure-iot-device/tests/iothub/auth/test_base_renewable_token_authentication_provider.py @@ -49,16 +49,12 @@ class FakeAuthProvider(BaseRenewableTokenAuthenticationProvider): @pytest.fixture(scope="function") def device_auth_provider(): - auth_provider = FakeAuthProvider(fake_hostname, fake_device_id, None) - yield auth_provider - auth_provider.disconnect() + return FakeAuthProvider(fake_hostname, fake_device_id, None) @pytest.fixture(scope="function") def module_auth_provider(): - auth_provider = FakeAuthProvider(fake_hostname, fake_device_id, fake_module_id) - yield auth_provider - auth_provider.disconnect() + return FakeAuthProvider(fake_hostname, fake_device_id, fake_module_id) @pytest.fixture(scope="function") @@ -102,13 +98,14 @@ def test_get_current_sas_token_returns_existing_sas_token(device_auth_provider): assert token1 == token2 -def test_generate_new_sas_token_calls_on_sas_token_updated_handler_when_sas_udpates( +def test_generate_new_sas_token_calls_on_sas_token_updated_handler_when_sas_updates( device_auth_provider ): - update_callback = MagicMock() - device_auth_provider.on_sas_token_updated_handler = update_callback + update_callback_list = [MagicMock(), MagicMock(), MagicMock()] + device_auth_provider.on_sas_token_updated_handler_list = update_callback_list device_auth_provider.generate_new_sas_token() - update_callback.assert_called_once_with() + for x in update_callback_list: + x.assert_called_once_with() def test_device_generate_new_sas_token_calls_sign_with_correct_default_args( @@ -163,17 +160,21 @@ def test_generate_new_sas_token_cancels_and_reschedules_update_timer_with_correc def test_update_timer_generates_new_sas_token_and_calls_on_sas_token_updated_handler( device_auth_provider, fake_timer_object ): - update_callback = MagicMock() + update_callback_list = [MagicMock(), MagicMock(), MagicMock()] device_auth_provider.generate_new_sas_token() - device_auth_provider.on_sas_token_updated_handler = update_callback + device_auth_provider.on_sas_token_updated_handler_list = update_callback_list timer_callback = fake_timer_object.call_args[0][1] device_auth_provider._sign.reset_mock() timer_callback() - update_callback.assert_called_once_with() + for x in update_callback_list: + x.assert_called_once_with() assert device_auth_provider._sign.call_count == 1 -def test_disconnect_cancels_update_timer(device_auth_provider, fake_timer_object): +def test_finalizer_cancels_update_timer(fake_timer_object): + # can't use the device_auth_provider fixture here because the fixture adds + # to the object refcount and prevents del from calling the finalizer + device_auth_provider = FakeAuthProvider(fake_hostname, fake_device_id, None) device_auth_provider.generate_new_sas_token() - device_auth_provider.disconnect() + del device_auth_provider fake_timer_object.return_value.cancel.assert_called_once_with() diff --git a/azure-iot-device/tests/iothub/auth/test_iotedge_authentication_provider.py b/azure-iot-device/tests/iothub/auth/test_iotedge_authentication_provider.py index e894007cc..093243997 100644 --- a/azure-iot-device/tests/iothub/auth/test_iotedge_authentication_provider.py +++ b/azure-iot-device/tests/iothub/auth/test_iotedge_authentication_provider.py @@ -16,7 +16,7 @@ from azure.iot.device.iothub.auth.iotedge_authentication_provider import ( IoTEdgeError, ) from .shared_auth_tests import SharedBaseRenewableAuthenticationProviderInstantiationTests -from azure.iot.device import constant +from azure.iot.device.product_info import ProductInfo logging.basicConfig(level=logging.DEBUG) @@ -108,10 +108,10 @@ class TestIoTEdgeAuthenticationProviderInstantiation( assert auth_provider.hsm is mock_hsm @pytest.mark.it( - "Sets a certificate acquired from the IoTEdgeHsm as the ca_cert instance attribute" + "Sets a certificate acquired from the IoTEdgeHsm as the server_verification_cert instance attribute" ) - def test_ca_cert_from_edge_hsm(self, auth_provider, mock_hsm): - assert auth_provider.ca_cert is mock_hsm.get_trust_bundle.return_value + def test_server_verification_cert_from_edge_hsm(self, auth_provider, mock_hsm): + assert auth_provider.server_verification_cert is mock_hsm.get_trust_bundle.return_value assert mock_hsm.get_trust_bundle.call_count == 1 @@ -192,7 +192,9 @@ class TestIoTEdgeHsmGetTrustBundle(object): mock_request_get = mocker.patch.object(requests, "get") expected_url = hsm.workload_uri + "trust-bundle" expected_params = {"api-version": hsm.api_version} - expected_headers = {"User-Agent": urllib.parse.quote_plus(constant.USER_AGENT)} + expected_headers = { + "User-Agent": urllib.parse.quote_plus(ProductInfo.get_iothub_user_agent()) + } hsm.get_trust_bundle() @@ -215,19 +217,23 @@ class TestIoTEdgeHsmGetTrustBundle(object): def test_bad_request(self, mocker, hsm): mock_request_get = mocker.patch.object(requests, "get") mock_response = mock_request_get.return_value - mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError + error = requests.exceptions.HTTPError() + mock_response.raise_for_status.side_effect = error - with pytest.raises(IoTEdgeError): + with pytest.raises(IoTEdgeError) as e_info: hsm.get_trust_bundle() + assert e_info.value.__cause__ is error @pytest.mark.it("Raises IoTEdgeError if there is an error in json decoding the trust bundle") def test_bad_json(self, mocker, hsm): mock_request_get = mocker.patch.object(requests, "get") mock_response = mock_request_get.return_value - mock_response.json.side_effect = ValueError + error = ValueError() + mock_response.json.side_effect = error - with pytest.raises(IoTEdgeError): + with pytest.raises(IoTEdgeError) as e_info: hsm.get_trust_bundle() + assert e_info.value.__cause__ is error @pytest.mark.it("Raises IoTEdgeError if the certificate is missing from the trust bundle") def test_bad_trust_bundle(self, mocker, hsm): @@ -254,7 +260,9 @@ class TestIoTEdgeHsmSign(object): module_generation_id=hsm.module_generation_id, ) expected_params = {"api-version": hsm.api_version} - expected_headers = {"User-Agent": urllib.parse.quote_plus(constant.USER_AGENT)} + expected_headers = { + "User-Agent": urllib.parse.quote_plus(ProductInfo.get_iothub_user_agent()) + } expected_json = json.dumps({"keyId": "primary", "algo": "HMACSHA256", "data": data_str_b64}) hsm.sign(data_str) @@ -306,19 +314,22 @@ class TestIoTEdgeHsmSign(object): def test_bad_request(self, mocker, hsm): mock_request_post = mocker.patch.object(requests, "post") mock_response = mock_request_post.return_value - mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError + error = requests.exceptions.HTTPError() + mock_response.raise_for_status.side_effect = error - with pytest.raises(IoTEdgeError): + with pytest.raises(IoTEdgeError) as e_info: hsm.sign("somedata") + assert e_info.value.__cause__ is error @pytest.mark.it("Raises IoTEdgeError if there is an error in json decoding the signed response") def test_bad_json(self, mocker, hsm): mock_request_post = mocker.patch.object(requests, "post") mock_response = mock_request_post.return_value - mock_response.json.side_effect = ValueError - - with pytest.raises(IoTEdgeError): + error = ValueError() + mock_response.json.side_effect = error + with pytest.raises(IoTEdgeError) as e_info: hsm.sign("somedata") + assert e_info.value.__cause__ is error @pytest.mark.it("Raises IoTEdgeError if the signed data is missing from the response") def test_bad_response(self, mocker, hsm): diff --git a/azure-iot-device/tests/iothub/auth/test_sk_authentication_provider.py b/azure-iot-device/tests/iothub/auth/test_sk_authentication_provider.py index 77e5f613e..2f9f31dce 100644 --- a/azure-iot-device/tests/iothub/auth/test_sk_authentication_provider.py +++ b/azure-iot-device/tests/iothub/auth/test_sk_authentication_provider.py @@ -33,13 +33,10 @@ def test_all_attributes_for_device(): hostname, device_id, shared_access_key ) sym_key_auth_provider = SymmetricKeyAuthenticationProvider.parse(connection_string) - try: - assert sym_key_auth_provider.hostname == hostname - assert sym_key_auth_provider.device_id == device_id - assert hostname in sym_key_auth_provider.get_current_sas_token() - assert device_id in sym_key_auth_provider.get_current_sas_token() - finally: - sym_key_auth_provider.disconnect() + + assert sym_key_auth_provider.device_id == device_id + assert hostname in sym_key_auth_provider.get_current_sas_token() + assert device_id in sym_key_auth_provider.get_current_sas_token() def test_all_attributes_for_module(): @@ -47,15 +44,13 @@ def test_all_attributes_for_module(): hostname, device_id, module_id, shared_access_key ) sym_key_auth_provider = SymmetricKeyAuthenticationProvider.parse(connection_string) - try: - assert sym_key_auth_provider.hostname == hostname - assert sym_key_auth_provider.device_id == device_id - assert sym_key_auth_provider.module_id == module_id - assert hostname in sym_key_auth_provider.get_current_sas_token() - assert device_id in sym_key_auth_provider.get_current_sas_token() - assert module_id in sym_key_auth_provider.get_current_sas_token() - finally: - sym_key_auth_provider.disconnect() + + assert sym_key_auth_provider.hostname == hostname + assert sym_key_auth_provider.device_id == device_id + assert sym_key_auth_provider.module_id == module_id + assert hostname in sym_key_auth_provider.get_current_sas_token() + assert device_id in sym_key_auth_provider.get_current_sas_token() + assert module_id in sym_key_auth_provider.get_current_sas_token() def test_sastoken_keyname_device(): @@ -65,12 +60,9 @@ def test_sastoken_keyname_device(): sym_key_auth_provider = SymmetricKeyAuthenticationProvider.parse(connection_string) - try: - assert hostname in sym_key_auth_provider.get_current_sas_token() - assert device_id in sym_key_auth_provider.get_current_sas_token() - assert shared_access_key_name in sym_key_auth_provider.get_current_sas_token() - finally: - sym_key_auth_provider.disconnect() + assert hostname in sym_key_auth_provider.get_current_sas_token() + assert device_id in sym_key_auth_provider.get_current_sas_token() + assert shared_access_key_name in sym_key_auth_provider.get_current_sas_token() def test_raises_when_auth_provider_created_from_empty_connection_string(): diff --git a/azure-iot-device/tests/iothub/client_fixtures.py b/azure-iot-device/tests/iothub/client_fixtures.py index 79f91a837..8358a0a59 100644 --- a/azure-iot-device/tests/iothub/client_fixtures.py +++ b/azure-iot-device/tests/iothub/client_fixtures.py @@ -164,7 +164,10 @@ def edge_local_debug_environment(): shared_access_key=shared_access_key, gateway_hostname=gateway_hostname, ) - return {"EdgeHubConnectionString": cs, "EdgeModuleCACertificateFile": "__FAKE_CA_CERTIFICATE__"} + return { + "EdgeHubConnectionString": cs, + "EdgeModuleCACertificateFile": "__FAKE_SERVER_VERIFICATION_CERTIFICATE__", + } """----Shared mock pipeline fixture----""" @@ -202,6 +205,22 @@ class FakeIoTHubPipeline: callback() +class FakeHTTPPipeline: + def __init__(self): + pass + + def invoke_method(self, device_id, method_params, callback, module_id=None): + callback(invoke_method_response="__fake_method_response__") + + def get_storage_info_for_blob(self, blob_name, callback): + callback(storage_info="__fake_storage_info__") + + def notify_blob_upload_status( + self, correlation_id, is_success, status_code, status_description, callback + ): + callback() + + @pytest.fixture def iothub_pipeline(mocker): """This fixture will automatically handle callbacks and should be @@ -219,8 +238,19 @@ def iothub_pipeline_manual_cb(mocker): @pytest.fixture -def edge_pipeline(mocker): - return mocker.MagicMock() # TODO: change this to wrap a pipeline object +def http_pipeline(mocker): + """This fixture will automatically handle callbacks and should be + used in the majority of tests + """ + return mocker.MagicMock(wraps=FakeHTTPPipeline()) + + +@pytest.fixture +def http_pipeline_manual_cb(mocker): + """This fixture is for use in tests where manual triggering of a + callback is required + """ + return mocker.MagicMock() @pytest.fixture @@ -228,6 +258,19 @@ def fake_twin(): return {"fake_twin": True} +"""----Shared symmetric key fixtures----""" + + @pytest.fixture -def fake_error(): - return RuntimeError("__fake_error__") +def symmetric_key(): + return shared_access_key + + +@pytest.fixture +def hostname_fixture(): + return hostname + + +@pytest.fixture +def device_id_fixture(): + return device_id diff --git a/azure-iot-device/tests/iothub/conftest.py b/azure-iot-device/tests/iothub/conftest.py index c350d072c..7b8777e5b 100644 --- a/azure-iot-device/tests/iothub/conftest.py +++ b/azure-iot-device/tests/iothub/conftest.py @@ -16,7 +16,8 @@ from .client_fixtures import ( twin_patch_reported, iothub_pipeline, iothub_pipeline_manual_cb, - edge_pipeline, + http_pipeline, + http_pipeline_manual_cb, device_connection_string, module_connection_string, device_sas_token_string, @@ -25,7 +26,9 @@ from .client_fixtures import ( edge_local_debug_environment, x509, fake_twin, - fake_error, + symmetric_key, + device_id_fixture, + hostname_fixture, ) collect_ignore = [] diff --git a/azure-iot-device/tests/iothub/pipeline/conftest.py b/azure-iot-device/tests/iothub/pipeline/conftest.py index 4168a3db8..f0fbc1a79 100644 --- a/azure-iot-device/tests/iothub/pipeline/conftest.py +++ b/azure-iot-device/tests/iothub/pipeline/conftest.py @@ -5,11 +5,9 @@ # -------------------------------------------------------------------------- from tests.common.pipeline.fixtures import ( - callback, - fake_exception, - fake_base_exception, - event, fake_pipeline_thread, fake_non_pipeline_thread, unhandled_error_handler, + arbitrary_op, + arbitrary_event, ) diff --git a/azure-iot-device/tests/iothub/pipeline/test_config.py b/azure-iot-device/tests/iothub/pipeline/test_config.py new file mode 100644 index 000000000..5957dd556 --- /dev/null +++ b/azure-iot-device/tests/iothub/pipeline/test_config.py @@ -0,0 +1,43 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest +import logging +from tests.common.pipeline.pipeline_config_test import PipelineConfigInstantiationTestBase +from azure.iot.device.iothub.pipeline.config import IoTHubPipelineConfig + + +@pytest.mark.describe("IoTHubPipelineConfig - Instantiation") +class TestIoTHubPipelineConfigInstantiation(PipelineConfigInstantiationTestBase): + @pytest.fixture + def config_cls(self): + # This fixture is needed for the parent class + return IoTHubPipelineConfig + + @pytest.mark.it( + "Instantiates with the 'product_info' attribute set to the provided 'product_info' parameter" + ) + def test_product_info_set(self): + my_product_info = "some_info" + config = IoTHubPipelineConfig(product_info=my_product_info) + + assert config.product_info == my_product_info + + @pytest.mark.it( + "Instantiates with the 'product_info' attribute defaulting to empty string if there is no provided 'product_info'" + ) + def test_product_info_default(self): + config = IoTHubPipelineConfig() + assert config.product_info == "" + + @pytest.mark.it("Instantiates with the 'blob_upload' attribute set to False") + def test_blob_upload(self): + config = IoTHubPipelineConfig() + assert config.blob_upload is False + + @pytest.mark.it("Instantiates with the 'method_invoke' attribute set to False") + def test_method_invoke(self): + config = IoTHubPipelineConfig() + assert config.method_invoke is False diff --git a/azure-iot-device/tests/iothub/pipeline/test_http_path_iothub.py b/azure-iot-device/tests/iothub/pipeline/test_http_path_iothub.py new file mode 100644 index 000000000..25a2aa649 --- /dev/null +++ b/azure-iot-device/tests/iothub/pipeline/test_http_path_iothub.py @@ -0,0 +1,119 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest +import logging +from azure.iot.device.iothub.pipeline import http_path_iothub + +logging.basicConfig(level=logging.DEBUG) + +# NOTE: All tests are parametrized with multiple values for URL encoding. This is to show that the +# URL encoding is done correctly - not all URL encoding encodes the '+' character. Thus we must +# make sure any URL encoded value can encode a '+' specifically, in addition to regular encoding. + + +@pytest.mark.describe(".get_method_invoke_path()") +class TestGetMethodInvokePath(object): + @pytest.mark.it("Returns the method invoke HTTP path") + @pytest.mark.parametrize( + "device_id, module_id, expected_path", + [ + pytest.param( + "my_device", + None, + "twins/my_device/methods", + id="'my_device' ==> 'twins/my_device/methods'", + ), + pytest.param( + "my/device", + None, + "twins/my%2Fdevice/methods", + id="'my/device' ==> 'twins/my%2Fdevice/methods'", + ), + pytest.param( + "my+device", + None, + "twins/my%2Bdevice/methods", + id="'my+device' ==> 'twins/my%2Bdevice/methods'", + ), + pytest.param( + "my_device", + "my_module", + "twins/my_device/modules/my_module/methods", + id="('my_device', 'my_module') ==> 'twins/my_device/modules/my_module/methods'", + ), + pytest.param( + "my/device", + "my?module", + "twins/my%2Fdevice/modules/my%3Fmodule/methods", + id="('my/device', 'my?module') ==> 'twins/my%2Fdevice/modules/my%3Fmodule/methods'", + ), + pytest.param( + "my+device", + "my+module", + "twins/my%2Bdevice/modules/my%2Bmodule/methods", + id="('my+device', 'my+module') ==> 'twins/my%2Bdevice/modules/my%2Bmodule/methods'", + ), + ], + ) + def test_path(self, device_id, module_id, expected_path): + path = http_path_iothub.get_method_invoke_path(device_id=device_id, module_id=module_id) + assert path == expected_path + + +@pytest.mark.describe(".get_storage_info_for_blob_path()") +class TestGetStorageInfoPath(object): + @pytest.mark.it("Returns the storage info HTTP path") + @pytest.mark.parametrize( + "device_id, expected_path", + [ + pytest.param( + "my_device", + "devices/my_device/files", + id="'my_device' ==> 'devices/my_device/files'", + ), + pytest.param( + "my/device", + "devices/my%2Fdevice/files", + id="'my/device' ==> 'devices/my%2Fdevice/files'", + ), + pytest.param( + "my+device", + "devices/my%2Bdevice/files", + id="'my+device' ==> 'devices/my%2Bdevice/files'", + ), + ], + ) + def test_path(self, device_id, expected_path): + path = http_path_iothub.get_storage_info_for_blob_path(device_id) + assert path == expected_path + + +@pytest.mark.describe(".get_notify_blob_upload_status_path()") +class TestGetNotifyBlobUploadStatusPath(object): + @pytest.mark.it("Returns the notify blob upload status HTTP path") + @pytest.mark.parametrize( + "device_id, expected_path", + [ + pytest.param( + "my_device", + "devices/my_device/files/notifications", + id="'my_device' ==> 'devices/my_device/files/notifications'", + ), + pytest.param( + "my/device", + "devices/my%2Fdevice/files/notifications", + id="'my/device' ==> 'devices/my%2Fdevice/files/notifications'", + ), + pytest.param( + "my+device", + "devices/my%2Bdevice/files/notifications", + id="'my+device' ==> 'devices/my%2Bdevice/files/notifications'", + ), + ], + ) + def test_path(self, device_id, expected_path): + path = http_path_iothub.get_notify_blob_upload_status_path(device_id) + assert path == expected_path diff --git a/azure-iot-device/tests/iothub/pipeline/test_http_pipeline.py b/azure-iot-device/tests/iothub/pipeline/test_http_pipeline.py new file mode 100644 index 000000000..8109798cc --- /dev/null +++ b/azure-iot-device/tests/iothub/pipeline/test_http_pipeline.py @@ -0,0 +1,402 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import pytest +import logging +import six.moves.urllib as urllib +from azure.iot.device.common import handle_exceptions +from azure.iot.device.common.pipeline import ( + pipeline_stages_base, + pipeline_stages_http, + pipeline_ops_base, +) +from azure.iot.device.iothub.pipeline import ( + pipeline_stages_iothub, + pipeline_stages_iothub_http, + pipeline_ops_iothub, + pipeline_ops_iothub_http, +) +from azure.iot.device.iothub.pipeline import HTTPPipeline, constant +from azure.iot.device.iothub.auth import ( + SymmetricKeyAuthenticationProvider, + X509AuthenticationProvider, +) + +logging.basicConfig(level=logging.DEBUG) +pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") + +fake_device_id = "__fake_device_id__" +fake_module_id = "__fake_module_id__" +fake_blob_name = "__fake_blob_name__" + + +@pytest.fixture +def auth_provider(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def pipeline_configuration(mocker): + mocked_configuration = mocker.MagicMock() + mocked_configuration.blob_upload = True + mocked_configuration.method_invoke = True + return mocked_configuration + + +@pytest.fixture +def pipeline(mocker, auth_provider, pipeline_configuration): + pipeline = HTTPPipeline(auth_provider, pipeline_configuration) + mocker.patch.object(pipeline._pipeline, "run_op") + return pipeline + + +@pytest.fixture +def twin_patch(): + return {"key": "value"} + + +# automatically mock the transport for all tests in this file. +@pytest.fixture(autouse=True) +def mock_transport(mocker): + print("mocking transport") + mocker.patch( + "azure.iot.device.common.pipeline.pipeline_stages_http.HTTPTransport", autospec=True + ) + + +@pytest.mark.describe("HTTPPipeline - Instantiation") +class TestHTTPPipelineInstantiation(object): + @pytest.mark.it("Configures the pipeline with a series of PipelineStages") + def test_pipeline_configuration(self, auth_provider, pipeline_configuration): + pipeline = HTTPPipeline(auth_provider, pipeline_configuration) + curr_stage = pipeline._pipeline + + expected_stage_order = [ + pipeline_stages_base.PipelineRootStage, + pipeline_stages_iothub.UseAuthProviderStage, + pipeline_stages_iothub_http.IoTHubHTTPTranslationStage, + pipeline_stages_http.HTTPTransportStage, + ] + + # Assert that all PipelineStages are there, and they are in the right order + for i in range(len(expected_stage_order)): + expected_stage = expected_stage_order[i] + assert isinstance(curr_stage, expected_stage) + curr_stage = curr_stage.next + + # Assert there are no more additional stages + assert curr_stage is None + + # TODO: revist these tests after auth revision + # They are too tied to auth types (and there's too much variance in auths to effectively test) + # Ideally HTTPPipeline is entirely insulated from any auth differential logic (and module/device distinctions) + # In the meantime, we are using a device auth with connection string to stand in for generic SAS auth + # and device auth with X509 certs to stand in for generic X509 auth + @pytest.mark.it( + "Runs a SetAuthProviderOperation with the provided AuthenticationProvider on the pipeline, if using SAS based authentication" + ) + def test_sas_auth(self, mocker, device_connection_string, pipeline_configuration): + mocker.spy(pipeline_stages_base.PipelineRootStage, "run_op") + auth_provider = SymmetricKeyAuthenticationProvider.parse(device_connection_string) + pipeline = HTTPPipeline(auth_provider, pipeline_configuration) + op = pipeline._pipeline.run_op.call_args[0][1] + assert pipeline._pipeline.run_op.call_count == 1 + assert isinstance(op, pipeline_ops_iothub.SetAuthProviderOperation) + assert op.auth_provider is auth_provider + + @pytest.mark.it( + "Raises exceptions that occurred in execution upon unsuccessful completion of the SetAuthProviderOperation" + ) + def test_sas_auth_op_fail( + self, mocker, device_connection_string, arbitrary_exception, pipeline_configuration + ): + old_run_op = pipeline_stages_base.PipelineRootStage._run_op + + def fail_set_auth_provider(self, op): + if isinstance(op, pipeline_ops_iothub.SetAuthProviderOperation): + op.complete(error=arbitrary_exception) + else: + old_run_op(self, op) + + mocker.patch.object( + pipeline_stages_base.PipelineRootStage, + "_run_op", + side_effect=fail_set_auth_provider, + autospec=True, + ) + + auth_provider = SymmetricKeyAuthenticationProvider.parse(device_connection_string) + with pytest.raises(arbitrary_exception.__class__) as e_info: + HTTPPipeline(auth_provider, pipeline_configuration) + assert e_info.value is arbitrary_exception + + @pytest.mark.it( + "Runs a SetX509AuthProviderOperation with the provided AuthenticationProvider on the pipeline, if using SAS based authentication" + ) + def test_cert_auth(self, mocker, x509, pipeline_configuration): + mocker.spy(pipeline_stages_base.PipelineRootStage, "run_op") + auth_provider = X509AuthenticationProvider( + hostname="somehostname", device_id=fake_device_id, x509=x509 + ) + pipeline = HTTPPipeline(auth_provider, pipeline_configuration) + op = pipeline._pipeline.run_op.call_args[0][1] + assert pipeline._pipeline.run_op.call_count == 1 + assert isinstance(op, pipeline_ops_iothub.SetX509AuthProviderOperation) + assert op.auth_provider is auth_provider + + @pytest.mark.it( + "Raises exceptions that occurred in execution upon unsuccessful completion of the SetX509AuthProviderOperation" + ) + def test_cert_auth_op_fail(self, mocker, x509, arbitrary_exception, pipeline_configuration): + old_run_op = pipeline_stages_base.PipelineRootStage._run_op + + def fail_set_auth_provider(self, op): + if isinstance(op, pipeline_ops_iothub.SetX509AuthProviderOperation): + op.complete(error=arbitrary_exception) + else: + old_run_op(self, op) + + mocker.patch.object( + pipeline_stages_base.PipelineRootStage, + "_run_op", + side_effect=fail_set_auth_provider, + autospec=True, + ) + + auth_provider = X509AuthenticationProvider( + hostname="somehostname", device_id=fake_device_id, x509=x509 + ) + with pytest.raises(arbitrary_exception.__class__): + HTTPPipeline(auth_provider, pipeline_configuration) + + +@pytest.mark.describe("HTTPPipeline - .invoke_method()") +class TestHTTPPipelineInvokeMethod(object): + @pytest.mark.it("Runs a MethodInvokeOperation on the pipeline") + def test_runs_op(self, pipeline, mocker): + cb = mocker.MagicMock() + pipeline.invoke_method( + device_id=fake_device_id, + module_id=fake_module_id, + method_params=mocker.MagicMock(), + callback=cb, + ) + assert pipeline._pipeline.run_op.call_count == 1 + assert isinstance( + pipeline._pipeline.run_op.call_args[0][0], + pipeline_ops_iothub_http.MethodInvokeOperation, + ) + + @pytest.mark.it( + "Calls the callback with the error if the pipeline_configuration.method_invoke is not True" + ) + def test_op_configuration_fail(self, mocker, pipeline, arbitrary_exception): + pipeline._pipeline.pipeline_configuration.method_invoke = False + cb = mocker.MagicMock() + + pipeline.invoke_method( + device_id=fake_device_id, + module_id=fake_module_id, + method_params=mocker.MagicMock(), + callback=cb, + ) + + assert cb.call_count == 1 + assert cb.call_args == mocker.call(error=mocker.ANY) + + @pytest.mark.it("Passes the correct parameters to the MethodInvokeOperation") + def test_passes_params_to_op(self, pipeline, mocker): + cb = mocker.MagicMock() + mocked_op = mocker.patch.object(pipeline_ops_iothub_http, "MethodInvokeOperation") + fake_method_params = mocker.MagicMock() + pipeline.invoke_method( + device_id=fake_device_id, + module_id=fake_module_id, + method_params=fake_method_params, + callback=cb, + ) + + assert mocked_op.call_args == mocker.call( + callback=mocker.ANY, + method_params=fake_method_params, + target_device_id=fake_device_id, + target_module_id=fake_module_id, + ) + + @pytest.mark.it("Triggers the callback upon successful completion of the MethodInvokeOperation") + def test_op_success_with_callback(self, mocker, pipeline): + cb = mocker.MagicMock() + + # Begin operation + pipeline.invoke_method( + device_id=fake_device_id, + module_id=fake_module_id, + method_params=mocker.MagicMock(), + callback=cb, + ) + assert cb.call_count == 0 + + # Trigger op completion + op = pipeline._pipeline.run_op.call_args[0][0] + op.method_response = "__fake_method_response__" + op.complete(error=None) + + assert cb.call_count == 1 + assert cb.call_args == mocker.call( + error=None, invoke_method_response="__fake_method_response__" + ) + + @pytest.mark.it( + "Calls the callback with the error upon unsuccessful completion of the MethodInvokeOperation" + ) + def test_op_fail(self, mocker, pipeline, arbitrary_exception): + cb = mocker.MagicMock() + + pipeline.invoke_method( + device_id=fake_device_id, + module_id=fake_module_id, + method_params=mocker.MagicMock(), + callback=cb, + ) + op = pipeline._pipeline.run_op.call_args[0][0] + + op.complete(error=arbitrary_exception) + assert cb.call_count == 1 + assert cb.call_args == mocker.call(error=arbitrary_exception, invoke_method_response=None) + + +@pytest.mark.describe("HTTPPipeline - .get_storage_info_for_blob()") +class TestHTTPPipelineGetStorageInfo(object): + @pytest.mark.it("Runs a GetStorageInfoOperation on the pipeline") + def test_runs_op(self, pipeline, mocker): + pipeline.get_storage_info_for_blob( + blob_name="__fake_blob_name__", callback=mocker.MagicMock() + ) + assert pipeline._pipeline.run_op.call_count == 1 + assert isinstance( + pipeline._pipeline.run_op.call_args[0][0], + pipeline_ops_iothub_http.GetStorageInfoOperation, + ) + + @pytest.mark.it( + "Calls the callback with the error upon unsuccessful completion of the GetStorageInfoOperation" + ) + def test_op_configuration_fail(self, mocker, pipeline): + pipeline._pipeline.pipeline_configuration.blob_upload = False + cb = mocker.MagicMock() + pipeline.get_storage_info_for_blob(blob_name="__fake_blob_name__", callback=cb) + + assert cb.call_count == 1 + assert cb.call_args == mocker.call(error=mocker.ANY) + + @pytest.mark.it( + "Triggers the callback upon successful completion of the GetStorageInfoOperation" + ) + def test_op_success_with_callback(self, mocker, pipeline): + cb = mocker.MagicMock() + + # Begin operation + pipeline.get_storage_info_for_blob(blob_name="__fake_blob_name__", callback=cb) + assert cb.call_count == 0 + + # Trigger op completion callback + op = pipeline._pipeline.run_op.call_args[0][0] + op.storage_info = "__fake_storage_info__" + op.complete(error=None) + + assert cb.call_count == 1 + assert cb.call_args == mocker.call(error=None, storage_info="__fake_storage_info__") + + @pytest.mark.it( + "Calls the callback with the error upon unsuccessful completion of the GetStorageInfoOperation" + ) + def test_op_fail(self, mocker, pipeline, arbitrary_exception): + cb = mocker.MagicMock() + pipeline.get_storage_info_for_blob(blob_name="__fake_blob_name__", callback=cb) + + op = pipeline._pipeline.run_op.call_args[0][0] + op.complete(error=arbitrary_exception) + + assert cb.call_count == 1 + assert cb.call_args == mocker.call(error=arbitrary_exception, storage_info=None) + + +@pytest.mark.describe("HTTPPipeline - .notify_blob_upload_status()") +class TestHTTPPipelineNotifyBlobUploadStatus(object): + @pytest.mark.it( + "Runs a NotifyBlobUploadStatusOperation with the provided parameters on the pipeline" + ) + def test_runs_op(self, pipeline, mocker): + pipeline.notify_blob_upload_status( + correlation_id="__fake_correlation_id__", + is_success="__fake_is_success__", + status_code="__fake_status_code__", + status_description="__fake_status_description__", + callback=mocker.MagicMock(), + ) + op = pipeline._pipeline.run_op.call_args[0][0] + + assert pipeline._pipeline.run_op.call_count == 1 + assert isinstance(op, pipeline_ops_iothub_http.NotifyBlobUploadStatusOperation) + + @pytest.mark.it( + "Calls the callback with the error if pipeline_configuration.blob_upload is not True" + ) + def test_op_configuration_fail(self, mocker, pipeline): + pipeline._pipeline.pipeline_configuration.blob_upload = False + cb = mocker.MagicMock() + pipeline.notify_blob_upload_status( + correlation_id="__fake_correlation_id__", + is_success="__fake_is_success__", + status_code="__fake_status_code__", + status_description="__fake_status_description__", + callback=cb, + ) + + assert cb.call_count == 1 + assert cb.call_args == mocker.call(error=mocker.ANY) + + @pytest.mark.it( + "Triggers the callback upon successful completion of the NotifyBlobUploadStatusOperation" + ) + def test_op_success_with_callback(self, mocker, pipeline): + cb = mocker.MagicMock() + + # Begin operation + pipeline.notify_blob_upload_status( + correlation_id="__fake_correlation_id__", + is_success="__fake_is_success__", + status_code="__fake_status_code__", + status_description="__fake_status_description__", + callback=cb, + ) + assert cb.call_count == 0 + + # Trigger op completion callback + op = pipeline._pipeline.run_op.call_args[0][0] + op.complete(error=None) + + assert cb.call_count == 1 + assert cb.call_args == mocker.call(error=None) + + @pytest.mark.it( + "Calls the callback with the error upon unsuccessful completion of the NotifyBlobUploadStatusOperation" + ) + def test_op_fail(self, mocker, pipeline, arbitrary_exception): + cb = mocker.MagicMock() + pipeline.notify_blob_upload_status( + correlation_id="__fake_correlation_id__", + is_success="__fake_is_success__", + status_code="__fake_status_code__", + status_description="__fake_status_description__", + callback=cb, + ) + + op = pipeline._pipeline.run_op.call_args[0][0] + op.complete(error=arbitrary_exception) + + assert cb.call_count == 1 + assert cb.call_args == mocker.call(error=arbitrary_exception) diff --git a/azure-iot-device/tests/iothub/pipeline/test_iothub_pipeline.py b/azure-iot-device/tests/iothub/pipeline/test_iothub_pipeline.py index d191d4f6d..bd0ead923 100644 --- a/azure-iot-device/tests/iothub/pipeline/test_iothub_pipeline.py +++ b/azure-iot-device/tests/iothub/pipeline/test_iothub_pipeline.py @@ -7,11 +7,11 @@ import pytest import logging import six.moves.urllib as urllib +from azure.iot.device.common import handle_exceptions from azure.iot.device.common.pipeline import ( pipeline_stages_base, pipeline_stages_mqtt, pipeline_ops_base, - operation_flow, ) from azure.iot.device.iothub.pipeline import ( pipeline_stages_iothub, @@ -27,6 +27,7 @@ from azure.iot.device.iothub.auth import ( ) logging.basicConfig(level=logging.DEBUG) +pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") # Update this list with features as they are added to the SDK all_features = [ @@ -44,8 +45,13 @@ def auth_provider(mocker): @pytest.fixture -def pipeline(mocker, auth_provider): - pipeline = IoTHubPipeline(auth_provider) +def pipeline_configuration(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def pipeline(mocker, auth_provider, pipeline_configuration): + pipeline = IoTHubPipeline(auth_provider, pipeline_configuration) mocker.patch.object(pipeline._pipeline, "run_op") return pipeline @@ -68,20 +74,20 @@ def mock_transport(mocker): class TestIoTHubPipelineInstantiation(object): @pytest.mark.it("Begins tracking the enabled/disabled status of features") @pytest.mark.parametrize("feature", all_features) - def test_features(self, auth_provider, feature): - pipeline = IoTHubPipeline(auth_provider) + def test_features(self, auth_provider, pipeline_configuration, feature): + pipeline = IoTHubPipeline(auth_provider, pipeline_configuration) pipeline.feature_enabled[feature] # No assertion required - if this doesn't raise a KeyError, it is a success @pytest.mark.it("Marks all features as disabled") - def test_features_disabled(self, auth_provider): - pipeline = IoTHubPipeline(auth_provider) + def test_features_disabled(self, auth_provider, pipeline_configuration): + pipeline = IoTHubPipeline(auth_provider, pipeline_configuration) for key in pipeline.feature_enabled: assert not pipeline.feature_enabled[key] @pytest.mark.it("Sets all handlers to an initial value of None") - def test_handlers_set_to_none(self, auth_provider): - pipeline = IoTHubPipeline(auth_provider) + def test_handlers_set_to_none(self, auth_provider, pipeline_configuration): + pipeline = IoTHubPipeline(auth_provider, pipeline_configuration) assert pipeline.on_connected is None assert pipeline.on_disconnected is None assert pipeline.on_c2d_message_received is None @@ -90,25 +96,28 @@ class TestIoTHubPipelineInstantiation(object): assert pipeline.on_twin_patch_received is None @pytest.mark.it("Configures the pipeline to trigger handlers in response to external events") - def test_handlers_configured(self, auth_provider): - pipeline = IoTHubPipeline(auth_provider) + def test_handlers_configured(self, auth_provider, pipeline_configuration): + pipeline = IoTHubPipeline(auth_provider, pipeline_configuration) assert pipeline._pipeline.on_pipeline_event_handler is not None assert pipeline._pipeline.on_connected_handler is not None assert pipeline._pipeline.on_disconnected_handler is not None @pytest.mark.it("Configures the pipeline with a series of PipelineStages") - def test_pipeline_configuration(self, auth_provider): - pipeline = IoTHubPipeline(auth_provider) + def test_pipeline_configuration(self, auth_provider, pipeline_configuration): + pipeline = IoTHubPipeline(auth_provider, pipeline_configuration) curr_stage = pipeline._pipeline expected_stage_order = [ pipeline_stages_base.PipelineRootStage, pipeline_stages_iothub.UseAuthProviderStage, - pipeline_stages_iothub.HandleTwinOperationsStage, + pipeline_stages_iothub.TwinRequestResponseStage, pipeline_stages_base.CoordinateRequestAndResponseStage, - pipeline_stages_iothub_mqtt.IoTHubMQTTConverterStage, - pipeline_stages_base.EnsureConnectionStage, - pipeline_stages_base.SerializeConnectOpsStage, + pipeline_stages_iothub_mqtt.IoTHubMQTTTranslationStage, + pipeline_stages_base.AutoConnectStage, + pipeline_stages_base.ReconnectStage, + pipeline_stages_base.ConnectionLockStage, + pipeline_stages_base.RetryStage, + pipeline_stages_base.OpTimeoutStage, pipeline_stages_mqtt.MQTTTransportStage, ] @@ -129,76 +138,79 @@ class TestIoTHubPipelineInstantiation(object): @pytest.mark.it( "Runs a SetAuthProviderOperation with the provided AuthenticationProvider on the pipeline, if using SAS based authentication" ) - def test_sas_auth(self, mocker, device_connection_string): + def test_sas_auth(self, mocker, device_connection_string, pipeline_configuration): mocker.spy(pipeline_stages_base.PipelineRootStage, "run_op") auth_provider = SymmetricKeyAuthenticationProvider.parse(device_connection_string) - pipeline = IoTHubPipeline(auth_provider) + pipeline = IoTHubPipeline(auth_provider, pipeline_configuration) op = pipeline._pipeline.run_op.call_args[0][1] assert pipeline._pipeline.run_op.call_count == 1 assert isinstance(op, pipeline_ops_iothub.SetAuthProviderOperation) assert op.auth_provider is auth_provider @pytest.mark.it( - "Propagates exceptions that occurred in execution upon unsuccessful completion of the SetAuthProviderOperation" + "Raises exceptions that occurred in execution upon unsuccessful completion of the SetAuthProviderOperation" ) - def test_sas_auth_op_fail(self, mocker, device_connection_string, fake_exception): - old_execute_op = pipeline_stages_base.PipelineRootStage._execute_op + def test_sas_auth_op_fail( + self, mocker, device_connection_string, arbitrary_exception, pipeline_configuration + ): + old_run_op = pipeline_stages_base.PipelineRootStage._run_op def fail_set_auth_provider(self, op): - if isinstance(op, pipeline_stages_base.SetAuthProviderOperation): - op.error = fake_exception - operation_flow.complete_op(stage=self, op=op) + if isinstance(op, pipeline_ops_iothub.SetAuthProviderOperation): + op.complete(error=arbitrary_exception) else: - old_execute_op(self, op) + old_run_op(self, op) mocker.patch.object( pipeline_stages_base.PipelineRootStage, - "_execute_op", + "_run_op", side_effect=fail_set_auth_provider, + autospec=True, ) auth_provider = SymmetricKeyAuthenticationProvider.parse(device_connection_string) - with pytest.raises(fake_exception.__class__): - IoTHubPipeline(auth_provider) + with pytest.raises(arbitrary_exception.__class__) as e_info: + IoTHubPipeline(auth_provider, pipeline_configuration) + assert e_info.value is arbitrary_exception @pytest.mark.it( "Runs a SetX509AuthProviderOperation with the provided AuthenticationProvider on the pipeline, if using SAS based authentication" ) - def test_cert_auth(self, mocker, x509): + def test_cert_auth(self, mocker, x509, pipeline_configuration): mocker.spy(pipeline_stages_base.PipelineRootStage, "run_op") auth_provider = X509AuthenticationProvider( hostname="somehostname", device_id="somedevice", x509=x509 ) - pipeline = IoTHubPipeline(auth_provider) + pipeline = IoTHubPipeline(auth_provider, pipeline_configuration) op = pipeline._pipeline.run_op.call_args[0][1] assert pipeline._pipeline.run_op.call_count == 1 assert isinstance(op, pipeline_ops_iothub.SetX509AuthProviderOperation) assert op.auth_provider is auth_provider @pytest.mark.it( - "Propagates exceptions that occurred in execution upon unsuccessful completion of the SetX509AuthProviderOperation" + "Raises exceptions that occurred in execution upon unsuccessful completion of the SetX509AuthProviderOperation" ) - def test_cert_auth_op_fail(self, mocker, x509, fake_exception): - old_execute_op = pipeline_stages_base.PipelineRootStage._execute_op + def test_cert_auth_op_fail(self, mocker, x509, arbitrary_exception, pipeline_configuration): + old_run_op = pipeline_stages_base.PipelineRootStage._run_op def fail_set_auth_provider(self, op): - if isinstance(op, pipeline_stages_base.SetX509AuthProviderOperation): - op.error = fake_exception - operation_flow.complete_op(stage=self, op=op) + if isinstance(op, pipeline_ops_iothub.SetX509AuthProviderOperation): + op.complete(error=arbitrary_exception) else: - old_execute_op(self, op) + old_run_op(self, op) mocker.patch.object( pipeline_stages_base.PipelineRootStage, - "_execute_op", + "_run_op", side_effect=fail_set_auth_provider, + autospec=True, ) auth_provider = X509AuthenticationProvider( hostname="somehostname", device_id="somedevice", x509=x509 ) - with pytest.raises(fake_exception.__class__): - IoTHubPipeline(auth_provider) + with pytest.raises(arbitrary_exception.__class__): + IoTHubPipeline(auth_provider, pipeline_configuration) @pytest.mark.describe("IoTHubPipeline - .connect()") @@ -220,26 +232,25 @@ class TestIoTHubPipelineConnect(object): pipeline.connect(callback=cb) assert cb.call_count == 0 - # Trigger op completion callback + # Trigger op completion op = pipeline._pipeline.run_op.call_args[0][0] - op.callback(op) + op.complete(error=None) assert cb.call_count == 1 - assert cb.call_args == mocker.call() + assert cb.call_args == mocker.call(error=None) @pytest.mark.it( "Calls the callback with the error upon unsuccessful completion of the ConnectOperation" ) - def test_op_fail(self, mocker, pipeline): + def test_op_fail(self, mocker, pipeline, arbitrary_exception): cb = mocker.MagicMock() pipeline.connect(callback=cb) op = pipeline._pipeline.run_op.call_args[0][0] - op.error = Exception() - op.callback(op) + op.complete(error=arbitrary_exception) assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=op.error) + assert cb.call_args == mocker.call(error=arbitrary_exception) @pytest.mark.describe("IoTHubPipeline - .disconnect()") @@ -262,24 +273,23 @@ class TestIoTHubPipelineDisconnect(object): # Trigger op completion callback op = pipeline._pipeline.run_op.call_args[0][0] - op.callback(op) + op.complete(error=None) assert cb.call_count == 1 - assert cb.call_args == mocker.call() + assert cb.call_args == mocker.call(error=None) @pytest.mark.it( "Calls the callback with the error upon unsuccessful completion of the DisconnectOperation" ) - def test_op_fail(self, mocker, pipeline): + def test_op_fail(self, mocker, pipeline, arbitrary_exception): cb = mocker.MagicMock() pipeline.disconnect(callback=cb) op = pipeline._pipeline.run_op.call_args[0][0] - op.error = Exception() - op.callback(op) + op.complete(error=arbitrary_exception) assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=op.error) + assert cb.call_args == mocker.call(error=arbitrary_exception) @pytest.mark.describe("IoTHubPipeline - .send_message()") @@ -305,24 +315,23 @@ class TestIoTHubPipelineSendD2CMessage(object): # Trigger op completion callback op = pipeline._pipeline.run_op.call_args[0][0] - op.callback(op) + op.complete(error=None) assert cb.call_count == 1 - assert cb.call_args == mocker.call() + assert cb.call_args == mocker.call(error=None) @pytest.mark.it( "Calls the callback with the error upon unsuccessful completion of the SendD2CMessageOperation" ) - def test_op_fail(self, mocker, pipeline, message): + def test_op_fail(self, mocker, pipeline, message, arbitrary_exception): cb = mocker.MagicMock() pipeline.send_message(message, callback=cb) op = pipeline._pipeline.run_op.call_args[0][0] - op.error = Exception() - op.callback(op) + op.complete(error=arbitrary_exception) assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=op.error) + assert cb.call_args == mocker.call(error=arbitrary_exception) @pytest.mark.describe("IoTHubPipeline - .send_output_event()") @@ -354,24 +363,23 @@ class TestIoTHubPipelineSendOutputEvent(object): # Trigger op completion callback op = pipeline._pipeline.run_op.call_args[0][0] - op.callback(op) + op.complete(error=None) assert cb.call_count == 1 - assert cb.call_args == mocker.call() + assert cb.call_args == mocker.call(error=None) @pytest.mark.it( "Calls the callback with the error upon unsuccessful completion of the SendOutputEventOperation" ) - def test_op_fail(self, mocker, pipeline, message): + def test_op_fail(self, mocker, pipeline, message, arbitrary_exception): cb = mocker.MagicMock() pipeline.send_output_event(message, callback=cb) op = pipeline._pipeline.run_op.call_args[0][0] - op.error = Exception() - op.callback(op) + op.complete(error=arbitrary_exception) assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=op.error) + assert cb.call_args == mocker.call(error=arbitrary_exception) @pytest.mark.describe("IoTHubPipeline - .send_method_response()") @@ -399,24 +407,23 @@ class TestIoTHubPipelineSendMethodResponse(object): # Trigger op completion callback op = pipeline._pipeline.run_op.call_args[0][0] - op.callback(op) + op.complete(error=None) assert cb.call_count == 1 - assert cb.call_args == mocker.call() + assert cb.call_args == mocker.call(error=None) @pytest.mark.it( "Calls the callback with the error upon unsuccessful completion of the SendMethodResponseOperation" ) - def test_op_fail(self, mocker, pipeline, method_response): + def test_op_fail(self, mocker, pipeline, method_response, arbitrary_exception): cb = mocker.MagicMock() pipeline.send_method_response(method_response, callback=cb) op = pipeline._pipeline.run_op.call_args[0][0] - op.error = Exception() - op.callback(op) + op.complete(error=arbitrary_exception) assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=op.error) + assert cb.call_args == mocker.call(error=arbitrary_exception) @pytest.mark.describe("IoTHubPipeline - .get_twin()") @@ -442,7 +449,7 @@ class TestIoTHubPipelineGetTwin(object): # Trigger op completion callback op = pipeline._pipeline.run_op.call_args[0][0] - op.callback(op) + op.complete(error=None) assert cb.call_count == 1 assert cb.call_args == mocker.call(twin=None) @@ -450,16 +457,15 @@ class TestIoTHubPipelineGetTwin(object): @pytest.mark.it( "Calls the callback with the error upon unsuccessful completion of the GetTwinOperation" ) - def test_op_fail(self, mocker, pipeline): + def test_op_fail(self, mocker, pipeline, arbitrary_exception): cb = mocker.MagicMock() pipeline.get_twin(callback=cb) op = pipeline._pipeline.run_op.call_args[0][0] - op.error = Exception() - op.callback(op) + op.complete(error=arbitrary_exception) assert cb.call_count == 1 - assert cb.call_args == mocker.call(twin=None, error=op.error) + assert cb.call_args == mocker.call(twin=None, error=arbitrary_exception) @pytest.mark.describe("IoTHubPipeline - .patch_twin_reported_properties()") @@ -487,24 +493,23 @@ class TestIoTHubPipelinePatchTwinReportedProperties(object): # Trigger op completion callback op = pipeline._pipeline.run_op.call_args[0][0] - op.callback(op) + op.complete(error=None) assert cb.call_count == 1 - assert cb.call_args == mocker.call() + assert cb.call_args == mocker.call(error=None) @pytest.mark.it( "Calls the callback with the error upon unsuccessful completion of the PatchTwinReportedPropertiesOperation" ) - def test_op_fail(self, mocker, pipeline, twin_patch): + def test_op_fail(self, mocker, pipeline, twin_patch, arbitrary_exception): cb = mocker.MagicMock() pipeline.patch_twin_reported_properties(twin_patch, callback=cb) op = pipeline._pipeline.run_op.call_args[0][0] - op.error = Exception() - op.callback(op) + op.complete(error=arbitrary_exception) assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=op.error) + assert cb.call_args == mocker.call(error=arbitrary_exception) @pytest.mark.describe("IoTHubPipeline - .enable_feature()") @@ -549,25 +554,24 @@ class TestIoTHubPipelineEnableFeature(object): # Trigger op completion callback op = pipeline._pipeline.run_op.call_args[0][0] - op.callback(op) + op.complete(error=None) assert cb.call_count == 1 - assert cb.call_args == mocker.call() + assert cb.call_args == mocker.call(error=None) @pytest.mark.it( "Calls the callback with the error upon unsuccessful completion of the EnableFeatureOperation" ) @pytest.mark.parametrize("feature", all_features) - def test_op_fail(self, mocker, pipeline, feature): + def test_op_fail(self, mocker, pipeline, feature, arbitrary_exception): cb = mocker.MagicMock() pipeline.enable_feature(feature, callback=cb) op = pipeline._pipeline.run_op.call_args[0][0] - op.error = Exception() - op.callback(op) + op.complete(error=arbitrary_exception) assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=op.error) + assert cb.call_args == mocker.call(error=arbitrary_exception) @pytest.mark.describe("IoTHubPipeline - .disable_feature()") @@ -614,28 +618,27 @@ class TestIoTHubPipelineDisableFeature(object): # Trigger op completion callback op = pipeline._pipeline.run_op.call_args[0][0] - op.callback(op) + op.complete(error=None) assert cb.call_count == 1 - assert cb.call_args == mocker.call() + assert cb.call_args == mocker.call(error=None) @pytest.mark.it( "Calls the callback with the error upon unsuccessful completion of the DisableFeatureOperation" ) @pytest.mark.parametrize("feature", all_features) - def _est_op_fail(self, mocker, pipeline, feature): + def _est_op_fail(self, mocker, pipeline, feature, arbitrary_exception): cb = mocker.MagicMock() pipeline.disable_feature(feature, callback=cb) op = pipeline._pipeline.run_op.call_args[0][0] - op.error = Exception() - op.callback(op) + op.complete(error=arbitrary_exception) assert cb.call_count == 1 - assert cb.call_args == mocker.call(error=op.error) + assert cb.call_args == mocker.call(error=arbitrary_exception) -@pytest.mark.describe("IoTHubPipeline - EVENT: Connected") +@pytest.mark.describe("IoTHubPipeline - OCCURANCE: Connected") class TestIoTHubPipelineEVENTConnect(object): @pytest.mark.it("Triggers the 'on_connected' handler") def test_with_handler(self, mocker, pipeline): @@ -657,7 +660,7 @@ class TestIoTHubPipelineEVENTConnect(object): # No assertions required - not throwing an exception means the test passed -@pytest.mark.describe("IoTHubPipeline - EVENT: Disconnected") +@pytest.mark.describe("IoTHubPipeline - OCCURANCE: Disconnected") class TestIoTHubPipelineEVENTDisconnect(object): @pytest.mark.it("Triggers the 'on_disconnected' handler") def test_with_handler(self, mocker, pipeline): @@ -679,7 +682,7 @@ class TestIoTHubPipelineEVENTDisconnect(object): # No assertions required - not throwing an exception means the test passed -@pytest.mark.describe("IoTHubPipeline - EVENT: C2D Message Received") +@pytest.mark.describe("IoTHubPipeline - OCCURANCE: C2D Message Received") class TestIoTHubPipelineEVENTRecieveC2DMessage(object): @pytest.mark.it( "Triggers the 'on_c2d_message_received' handler, passing the received message as an argument" @@ -707,7 +710,7 @@ class TestIoTHubPipelineEVENTRecieveC2DMessage(object): # No assertions required - not throwing an exception means the test passed -@pytest.mark.describe("IoTHubPipeline - EVENT: Input Message Received") +@pytest.mark.describe("IoTHubPipeline - OCCURANCE: Input Message Received") class TestIoTHubPipelineEVENTReceiveInputMessage(object): @pytest.mark.it( "Triggers the 'on_input_message_received' handler, passing the received message and input name as arguments" @@ -737,7 +740,7 @@ class TestIoTHubPipelineEVENTReceiveInputMessage(object): # No assertions required - not throwing an exception means the test passed -@pytest.mark.describe("IoTHubPipeline - EVENT: Method Request Received") +@pytest.mark.describe("IoTHubPipeline - OCCURANCE: Method Request Received") class TestIoTHubPipelineEVENTReceiveMethodRequest(object): @pytest.mark.it( "Triggers the 'on_method_request_received' handler, passing the received method request as an argument" @@ -767,7 +770,7 @@ class TestIoTHubPipelineEVENTReceiveMethodRequest(object): # No assertions required - not throwing an exception means the test passed -@pytest.mark.describe("IoTHubPipeline - EVENT: Twin Desired Properties Patch Received") +@pytest.mark.describe("IoTHubPipeline - OCCURANCE: Twin Desired Properties Patch Received") class TestIoTHubPipelineEVENTReceiveDesiredPropertiesPatch(object): @pytest.mark.it( "Triggers the 'on_twin_patch_received' handler, passing the received twin patch as an argument" @@ -793,3 +796,18 @@ class TestIoTHubPipelineEVENTReceiveDesiredPropertiesPatch(object): pipeline._pipeline.on_pipeline_event_handler(twin_patch_event) # No assertions required - not throwing an exception means the test passed + + +@pytest.mark.describe("IoTHubPipeline - PROPERTY .connected") +class TestIotHubPipelinePROPERTYConnected(object): + @pytest.mark.it("Cannot be changed") + def test_read_only(self, pipeline): + with pytest.raises(AttributeError): + pipeline.connected = not pipeline.connected + + @pytest.mark.it("Reflects the value of the root stage property of the same name") + def test_reflects_pipeline_property(self, pipeline): + pipeline._pipeline.connected = True + assert pipeline.connected + pipeline._pipeline.connected = False + assert not pipeline.connected diff --git a/azure-iot-device/tests/iothub/pipeline/test_pipeline_events_iothub.py b/azure-iot-device/tests/iothub/pipeline/test_pipeline_events_iothub.py index 09f03757b..c54e75c27 100644 --- a/azure-iot-device/tests/iothub/pipeline/test_pipeline_events_iothub.py +++ b/azure-iot-device/tests/iothub/pipeline/test_pipeline_events_iothub.py @@ -6,30 +6,30 @@ import sys import logging from azure.iot.device.iothub.pipeline import pipeline_events_iothub -from tests.common.pipeline import pipeline_data_object_test +from tests.common.pipeline import pipeline_event_test logging.basicConfig(level=logging.DEBUG) this_module = sys.modules[__name__] -pipeline_data_object_test.add_event_test( +pipeline_event_test.add_event_test( cls=pipeline_events_iothub.C2DMessageEvent, module=this_module, positional_arguments=["message"], keyword_arguments={}, ) -pipeline_data_object_test.add_event_test( +pipeline_event_test.add_event_test( cls=pipeline_events_iothub.InputMessageEvent, module=this_module, positional_arguments=["input_name", "message"], keyword_arguments={}, ) -pipeline_data_object_test.add_event_test( +pipeline_event_test.add_event_test( cls=pipeline_events_iothub.MethodRequestEvent, module=this_module, positional_arguments=["method_request"], keyword_arguments={}, ) -pipeline_data_object_test.add_event_test( +pipeline_event_test.add_event_test( cls=pipeline_events_iothub.TwinDesiredPropertiesPatchEvent, module=this_module, positional_arguments=["patch"], diff --git a/azure-iot-device/tests/iothub/pipeline/test_pipeline_ops_iothub.py b/azure-iot-device/tests/iothub/pipeline/test_pipeline_ops_iothub.py index d4891b5d7..ecb39f6f3 100644 --- a/azure-iot-device/tests/iothub/pipeline/test_pipeline_ops_iothub.py +++ b/azure-iot-device/tests/iothub/pipeline/test_pipeline_ops_iothub.py @@ -3,66 +3,313 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +import pytest import sys import logging from azure.iot.device.iothub.pipeline import pipeline_ops_iothub -from tests.common.pipeline import pipeline_data_object_test +from tests.common.pipeline import pipeline_ops_test logging.basicConfig(level=logging.DEBUG) this_module = sys.modules[__name__] +pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_iothub.SetAuthProviderOperation, - module=this_module, - positional_arguments=["auth_provider"], - keyword_arguments={"callback": None}, + +class SetAuthProviderOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_iothub.SetAuthProviderOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = {"auth_provider": mocker.MagicMock(), "callback": mocker.MagicMock()} + return kwargs + + +class SetAuthProviderOperationInstantiationTests(SetAuthProviderOperationTestConfig): + @pytest.mark.it( + "Initializes 'auth_provider' attribute with the provided 'auth_provider' parameter" + ) + def test_auth_provider(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.auth_provider is init_kwargs["auth_provider"] + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_iothub.SetAuthProviderOperation, + op_test_config_class=SetAuthProviderOperationTestConfig, + extended_op_instantiation_test_class=SetAuthProviderOperationInstantiationTests, ) -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_iothub.SetX509AuthProviderOperation, - module=this_module, - positional_arguments=["auth_provider"], - keyword_arguments={"callback": None}, + + +class SetX509AuthProviderOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_iothub.SetX509AuthProviderOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = {"auth_provider": mocker.MagicMock(), "callback": mocker.MagicMock()} + return kwargs + + +class SetX509AuthProviderOperationInstantiationTests(SetX509AuthProviderOperationTestConfig): + @pytest.mark.it( + "Initializes 'auth_provider' attribute with the provided 'auth_provider' parameter" + ) + def test_auth_provider(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.auth_provider is init_kwargs["auth_provider"] + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_iothub.SetX509AuthProviderOperation, + op_test_config_class=SetX509AuthProviderOperationTestConfig, + extended_op_instantiation_test_class=SetX509AuthProviderOperationInstantiationTests, ) -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_iothub.SetIoTHubConnectionArgsOperation, - module=this_module, - positional_arguments=["device_id", "hostname"], - keyword_arguments={ - "module_id": None, - "gateway_hostname": None, - "ca_cert": None, - "client_cert": None, - "sas_token": None, - "callback": None, - }, + + +class SetIoTHubConnectionArgsOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_iothub.SetIoTHubConnectionArgsOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = { + "device_id": "some_device_id", + "hostname": "some_hostname", + "callback": mocker.MagicMock(), + "module_id": "some_module_id", + "gateway_hostname": "some_gateway_hostname", + "server_verification_cert": "some_server_verification_cert", + "client_cert": "some_client_cert", + "sas_token": "some_sas_token", + } + return kwargs + + +class SetIoTHubConnectionArgsOperationInstantiationTests( + SetIoTHubConnectionArgsOperationTestConfig +): + @pytest.mark.it("Initializes 'device_id' attribute with the provided 'device_id' parameter") + def test_device_id(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.device_id == init_kwargs["device_id"] + + @pytest.mark.it("Initializes 'hostname' attribute with the provided 'hostname' parameter") + def test_hostname(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.hostname == init_kwargs["hostname"] + + @pytest.mark.it("Initializes 'module_id' attribute with the provided 'module_id' parameter") + def test_module_id(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.module_id == init_kwargs["module_id"] + + @pytest.mark.it( + "Initializes 'module_id' attribute to None if no 'module_id' parameter is provided" + ) + def test_module_id_default(self, cls_type, init_kwargs): + del init_kwargs["module_id"] + op = cls_type(**init_kwargs) + assert op.module_id is None + + @pytest.mark.it( + "Initializes 'gateway_hostname' attribute with the provided 'gateway_hostname' parameter" + ) + def test_gateway_hostname(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.gateway_hostname == init_kwargs["gateway_hostname"] + + @pytest.mark.it( + "Initializes 'gateway_hostname' attribute to None if no 'gateway_hostname' parameter is provided" + ) + def test_gateway_hostname_default(self, cls_type, init_kwargs): + del init_kwargs["gateway_hostname"] + op = cls_type(**init_kwargs) + assert op.gateway_hostname is None + + @pytest.mark.it( + "Initializes 'server_verification_cert' attribute with the provided 'server_verification_cert' parameter" + ) + def test_server_verification_cert(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.server_verification_cert == init_kwargs["server_verification_cert"] + + @pytest.mark.it( + "Initializes 'server_verification_cert' attribute to None if no 'server_verification_cert' parameter is provided" + ) + def test_server_verification_cert_default(self, cls_type, init_kwargs): + del init_kwargs["server_verification_cert"] + op = cls_type(**init_kwargs) + assert op.server_verification_cert is None + + @pytest.mark.it("Initializes 'client_cert' attribute with the provided 'client_cert' parameter") + def test_client_cert(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.client_cert == init_kwargs["client_cert"] + + @pytest.mark.it( + "Initializes 'client_cert' attribute to None if no 'client_cert' parameter is provided" + ) + def test_client_cert_default(self, cls_type, init_kwargs): + del init_kwargs["client_cert"] + op = cls_type(**init_kwargs) + assert op.client_cert is None + + @pytest.mark.it("Initializes 'sas_token' attribute with the provided 'sas_token' parameter") + def test_sas_token(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.sas_token == init_kwargs["sas_token"] + + @pytest.mark.it( + "Initializes 'sas_token' attribute to None if no 'sas_token' parameter is provided" + ) + def test_sas_token_default(self, cls_type, init_kwargs): + del init_kwargs["sas_token"] + op = cls_type(**init_kwargs) + assert op.sas_token is None + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_iothub.SetIoTHubConnectionArgsOperation, + op_test_config_class=SetIoTHubConnectionArgsOperationTestConfig, + extended_op_instantiation_test_class=SetIoTHubConnectionArgsOperationInstantiationTests, ) -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_iothub.SendD2CMessageOperation, - module=this_module, - positional_arguments=["message"], - keyword_arguments={"callback": None}, + + +class SendD2CMessageOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_iothub.SendD2CMessageOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = {"message": mocker.MagicMock(), "callback": mocker.MagicMock()} + return kwargs + + +class SendD2CMessageOperationInstantiationTests(SendD2CMessageOperationTestConfig): + @pytest.mark.it("Initializes 'message' attribute with the provided 'message' parameter") + def test_message(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.message is init_kwargs["message"] + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_iothub.SendD2CMessageOperation, + op_test_config_class=SendD2CMessageOperationTestConfig, + extended_op_instantiation_test_class=SendD2CMessageOperationInstantiationTests, ) -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_iothub.SendOutputEventOperation, - module=this_module, - positional_arguments=["message"], - keyword_arguments={"callback": None}, + + +class SendOutputEventOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_iothub.SendOutputEventOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = {"message": mocker.MagicMock(), "callback": mocker.MagicMock()} + return kwargs + + +class SendOutputEventOperationInstantiationTests(SendOutputEventOperationTestConfig): + @pytest.mark.it("Initializes 'message' attribute with the provided 'message' parameter") + def test_message(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.message is init_kwargs["message"] + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_iothub.SendOutputEventOperation, + op_test_config_class=SendOutputEventOperationTestConfig, + extended_op_instantiation_test_class=SendOutputEventOperationInstantiationTests, ) -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_iothub.SendMethodResponseOperation, - module=this_module, - positional_arguments=["method_response"], - keyword_arguments={"callback": None}, + + +class SendMethodResponseOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_iothub.SendMethodResponseOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = {"method_response": mocker.MagicMock(), "callback": mocker.MagicMock()} + return kwargs + + +class SendMethodResponseOperationInstantiationTests(SendMethodResponseOperationTestConfig): + @pytest.mark.it( + "Initializes 'method_response' attribute with the provided 'method_response' parameter" + ) + def test_method_response(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.method_response is init_kwargs["method_response"] + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_iothub.SendMethodResponseOperation, + op_test_config_class=SendMethodResponseOperationTestConfig, + extended_op_instantiation_test_class=SendMethodResponseOperationInstantiationTests, ) -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_iothub.GetTwinOperation, - module=this_module, - positional_arguments=[], - keyword_arguments={"callback": None}, + + +class GetTwinOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_iothub.GetTwinOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = {"callback": mocker.MagicMock()} + return kwargs + + +class GetTwinOperationInstantiationTests(GetTwinOperationTestConfig): + @pytest.mark.it("Initializes 'twin' attribute as None") + def test_twin(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.twin is None + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_iothub.GetTwinOperation, + op_test_config_class=GetTwinOperationTestConfig, + extended_op_instantiation_test_class=GetTwinOperationInstantiationTests, ) -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_iothub.PatchTwinReportedPropertiesOperation, - module=this_module, - positional_arguments=["patch"], - keyword_arguments={"callback": None}, + + +class PatchTwinReportedPropertiesOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_iothub.PatchTwinReportedPropertiesOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = {"patch": {"some": "patch"}, "callback": mocker.MagicMock()} + return kwargs + + +class PatchTwinReportedPropertiesOperationInstantiationTests( + PatchTwinReportedPropertiesOperationTestConfig +): + @pytest.mark.it("Initializes 'patch' attribute with the provided 'patch' parameter") + def test_patch(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.patch is init_kwargs["patch"] + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_iothub.PatchTwinReportedPropertiesOperation, + op_test_config_class=PatchTwinReportedPropertiesOperationTestConfig, + extended_op_instantiation_test_class=PatchTwinReportedPropertiesOperationInstantiationTests, ) diff --git a/azure-iot-device/tests/iothub/pipeline/test_pipeline_ops_iothub_http.py b/azure-iot-device/tests/iothub/pipeline/test_pipeline_ops_iothub_http.py new file mode 100644 index 000000000..236e5cd7c --- /dev/null +++ b/azure-iot-device/tests/iothub/pipeline/test_pipeline_ops_iothub_http.py @@ -0,0 +1,149 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest +import sys +import logging +from azure.iot.device.iothub.pipeline import pipeline_ops_iothub_http +from tests.common.pipeline import pipeline_ops_test + +logging.basicConfig(level=logging.DEBUG) +this_module = sys.modules[__name__] +pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") + +fake_device_id = "__fake_device_id__" +fake_module_id = "__fake_module_id__" + + +class MethodInvokeOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_iothub_http.MethodInvokeOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = { + "target_device_id": fake_device_id, + "target_module_id": fake_module_id, + "method_params": mocker.MagicMock(), + "callback": mocker.MagicMock(), + } + return kwargs + + +class MethodInvokeOperationInstantiationTests(MethodInvokeOperationTestConfig): + @pytest.mark.it("Initializes 'device_id' attribute with the provided 'device_id' parameter") + def test_device_id(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.target_device_id is init_kwargs["target_device_id"] + + @pytest.mark.it("Initializes 'module_id' attribute with the provided 'module_id' parameter") + def test_module_id(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.target_module_id is init_kwargs["target_module_id"] + + @pytest.mark.it( + "Initializes 'method_params' attribute with the provided 'method_params' parameter" + ) + def test_method_params(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.method_params is init_kwargs["method_params"] + + @pytest.mark.it("Initializes 'method_response' attribute as None") + def test_method_response(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.method_response is None + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_iothub_http.MethodInvokeOperation, + op_test_config_class=MethodInvokeOperationTestConfig, + extended_op_instantiation_test_class=MethodInvokeOperationInstantiationTests, +) + + +class GetStorageInfoOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_iothub_http.GetStorageInfoOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = {"blob_name": "__fake_blob_name__", "callback": mocker.MagicMock()} + return kwargs + + +class GetStorageInfoOperationInstantiationTests(GetStorageInfoOperationTestConfig): + @pytest.mark.it("Initializes 'blob_name' attribute with the provided 'blob_name' parameter") + def test_blob_name(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.blob_name is init_kwargs["blob_name"] + + @pytest.mark.it("Initializes 'storage_info' attribute as None") + def test_storage_info(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.storage_info is None + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_iothub_http.GetStorageInfoOperation, + op_test_config_class=GetStorageInfoOperationTestConfig, + extended_op_instantiation_test_class=GetStorageInfoOperationInstantiationTests, +) + + +class NotifyBlobUploadStatusOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_iothub_http.NotifyBlobUploadStatusOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = { + "correlation_id": "__fake_correlation_id__", + "is_success": "__fake_is_success__", + "status_code": "__fake_status_code__", + "status_description": "__fake_status_description__", + "callback": mocker.MagicMock(), + } + return kwargs + + +class NotifyBlobUploadStatusOperationInstantiationTests(NotifyBlobUploadStatusOperationTestConfig): + @pytest.mark.it( + "Initializes 'correlation_id' attribute with the provided 'correlation_id' parameter" + ) + def test_correlation_id(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.correlation_id is init_kwargs["correlation_id"] + + @pytest.mark.it("Initializes 'is_success' attribute with the provided 'is_success' parameter") + def test_is_success(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.is_success is init_kwargs["is_success"] + + @pytest.mark.it( + "Initializes 'request_status_code' attribute with the provided 'status_code' parameter" + ) + def test_request_status_code(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.request_status_code is init_kwargs["status_code"] + + @pytest.mark.it( + "Initializes 'status_description' attribute with the provided 'status_description' parameter" + ) + def test_status_description(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.status_description is init_kwargs["status_description"] + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_iothub_http.NotifyBlobUploadStatusOperation, + op_test_config_class=NotifyBlobUploadStatusOperationTestConfig, + extended_op_instantiation_test_class=NotifyBlobUploadStatusOperationInstantiationTests, +) diff --git a/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub.py b/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub.py index e5128dbf4..992121b3a 100644 --- a/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub.py +++ b/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub.py @@ -10,43 +10,27 @@ import pytest import sys import threading from concurrent.futures import Future -from azure.iot.device.common import unhandled_exceptions +from azure.iot.device.exceptions import ServiceError +from azure.iot.device.common import handle_exceptions from azure.iot.device.common.pipeline import pipeline_ops_base from azure.iot.device.iothub.pipeline import pipeline_stages_iothub, pipeline_ops_iothub -from tests.common.pipeline.helpers import ( - assert_callback_succeeded, - assert_callback_failed, - all_common_ops, - all_common_events, - all_except, - make_mock_stage, - UnhandledException, -) -from tests.iothub.pipeline.helpers import all_iothub_ops, all_iothub_events +from azure.iot.device.iothub.pipeline.exceptions import PipelineError +from azure.iot.device.iothub.auth.authentication_provider import AuthenticationProvider +from tests.common.pipeline.helpers import StageRunOpTestBase, StageHandlePipelineEventTestBase from tests.common.pipeline import pipeline_stage_test from azure.iot.device.common.models.x509 import X509 from azure.iot.device.iothub.auth.x509_authentication_provider import X509AuthenticationProvider logging.basicConfig(level=logging.DEBUG) - this_module = sys.modules[__name__] - - -# This fixture makes it look like all test in this file tests are running -# inside the pipeline thread. Because this is an autouse fixture, we -# manually add it to the individual test.py files that need it. If, -# instead, we had added it to some conftest.py, it would be applied to -# every tests in every file and we don't want that. -@pytest.fixture(autouse=True) -def apply_fake_pipeline_thread(fake_pipeline_thread): - pass +pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") fake_device_id = "__fake_device_id__" fake_module_id = "__fake_module_id__" fake_hostname = "__fake_hostname__" fake_gateway_hostname = "__fake_gateway_hostname__" -fake_ca_cert = "__fake_ca_cert__" +fake_server_verification_cert = "__fake_server_verification_cert__" fake_sas_token = "__fake_sas_token__" fake_symmetric_key = "Zm9vYmFy" fake_x509_cert_file = "fantastic_beasts" @@ -54,497 +38,697 @@ fake_x509_cert_key_file = "where_to_find_them" fake_pass_phrase = "alohomora" -pipeline_stage_test.add_base_pipeline_stage_tests( - cls=pipeline_stages_iothub.UseAuthProviderStage, - module=this_module, - all_ops=all_common_ops + all_iothub_ops, - handled_ops=[ - pipeline_ops_iothub.SetAuthProviderOperation, - pipeline_ops_iothub.SetX509AuthProviderOperation, - ], - all_events=all_common_events + all_iothub_events, - handled_events=[], - methods_that_enter_pipeline_thread=["on_sas_token_updated"], -) +################### +# COMMON FIXTURES # +################### -def make_mock_sas_token_auth_provider(): - class MockAuthProvider(object): - def get_current_sas_token(self): - return fake_sas_token - - auth_provider = MockAuthProvider() - auth_provider.device_id = fake_device_id - auth_provider.hostname = fake_hostname - return auth_provider +@pytest.fixture(params=[True, False], ids=["With error", "No error"]) +def op_error(request, arbitrary_exception): + if request.param: + return arbitrary_exception + else: + return None -def make_x509_auth_provider_device(): - mock_x509 = X509(fake_x509_cert_file, fake_x509_cert_key_file, fake_pass_phrase) - return X509AuthenticationProvider( - hostname=fake_hostname, device_id=fake_device_id, x509=mock_x509 - ) +@pytest.fixture +def mock_handle_background_exception(mocker): + mock_handler = mocker.patch.object(handle_exceptions, "handle_background_exception") + return mock_handler -def make_x509_auth_provider_module(): - mock_x509 = X509(fake_x509_cert_file, fake_x509_cert_key_file, fake_pass_phrase) - return X509AuthenticationProvider( - x509=mock_x509, hostname=fake_hostname, device_id=fake_device_id, module_id=fake_module_id - ) +########################### +# USE AUTH PROVIDER STAGE # +########################### -different_auth_provider_ops = [ - { - "name": "sas_token_auth", - "current_op_class": pipeline_ops_iothub.SetAuthProviderOperation, - "auth_provider_function_name": make_mock_sas_token_auth_provider, - }, - { - "name": "x509_auth_device", - "current_op_class": pipeline_ops_iothub.SetX509AuthProviderOperation, - "auth_provider_function_name": make_x509_auth_provider_device, - }, - { - "name": "x509_auth_module", - "current_op_class": pipeline_ops_iothub.SetX509AuthProviderOperation, - "auth_provider_function_name": make_x509_auth_provider_module, - }, -] - - -@pytest.mark.parametrize( - "params_auth_provider_ops", - different_auth_provider_ops, - ids=[x["current_op_class"].__name__ for x in different_auth_provider_ops], -) -@pytest.mark.describe("UseAuthProvider - .run_op() -- called with SetAuthProviderOperation") -class TestUseAuthProviderRunOpWithSetAuthProviderOperation(object): +class UseAuthProviderStageTestConfig(object): @pytest.fixture - def stage(self, mocker): - return make_mock_stage(mocker, pipeline_stages_iothub.UseAuthProviderStage) + def cls_type(self): + return pipeline_stages_iothub.UseAuthProviderStage @pytest.fixture - def set_auth_provider(self, callback, params_auth_provider_ops): - op = params_auth_provider_ops["current_op_class"]( - auth_provider=params_auth_provider_ops["auth_provider_function_name"]() - ) - op.callback = callback - return op + def init_kwargs(self): + return {} @pytest.fixture - def set_auth_provider_all_args(self, callback, params_auth_provider_ops): - auth_provider = params_auth_provider_ops["auth_provider_function_name"]() - auth_provider.module_id = fake_module_id - - if not isinstance(auth_provider, X509AuthenticationProvider): - auth_provider.ca_cert = fake_ca_cert - auth_provider.gateway_hostname = fake_gateway_hostname - auth_provider.sas_token = fake_sas_token - op = params_auth_provider_ops["current_op_class"](auth_provider=auth_provider) - op.callback = callback - return op - - @pytest.mark.it("Runs SetIoTHubConnectionArgsOperation op on the next stage") - def test_runs_set_auth_provider_args(self, mocker, stage, set_auth_provider): - stage.next._execute_op = mocker.Mock() - stage.run_op(set_auth_provider) - assert stage.next._execute_op.call_count == 1 - set_args = stage.next._execute_op.call_args[0][0] - assert isinstance(set_args, pipeline_ops_iothub.SetIoTHubConnectionArgsOperation) - - @pytest.mark.it( - "Sets the device_id, and hostname attributes on SetIoTHubConnectionArgsOperation based on the same-names auth_provider attributes" - ) - def test_sets_required_attributes(self, mocker, stage, set_auth_provider): - stage.next._execute_op = mocker.Mock() - stage.run_op(set_auth_provider) - set_args = stage.next._execute_op.call_args[0][0] - assert set_args.device_id == fake_device_id - assert set_args.hostname == fake_hostname - - @pytest.mark.it( - "Sets the gateway_hostname, ca_cert, and module_id attributes to None if they don't exist on the auth_provider object" - ) - def test_defaults_optional_attributes_to_none( - self, mocker, stage, set_auth_provider, params_auth_provider_ops - ): - stage.next._execute_op = mocker.Mock() - stage.run_op(set_auth_provider) - set_args = stage.next._execute_op.call_args[0][0] - assert set_args.gateway_hostname is None - assert set_args.ca_cert is None - if params_auth_provider_ops["name"] == "x509_auth_module": - assert set_args.module_id is not None - else: - assert set_args.module_id is None - - @pytest.mark.it( - "Sets the module_id, gateway_hostname, sas_token, and ca_cert attributes on SetIoTHubConnectionArgsOperation if they exist on the auth_provider object" - ) - def test_sets_optional_attributes( - self, mocker, stage, set_auth_provider_all_args, params_auth_provider_ops - ): - stage.next._execute_op = mocker.Mock() - stage.run_op(set_auth_provider_all_args) - set_args = stage.next._execute_op.call_args[0][0] - assert set_args.module_id == fake_module_id - - if params_auth_provider_ops["name"] == "sas_token_auth": - assert set_args.gateway_hostname == fake_gateway_hostname - assert set_args.ca_cert == fake_ca_cert - assert set_args.sas_token == fake_sas_token - - @pytest.mark.it( - "Handles any Exceptions raised by SetIoTHubConnectionArgsOperation and returns them through the op callback" - ) - def test_set_auth_provider_raises_exception( - self, mocker, stage, fake_exception, set_auth_provider - ): - stage.next._execute_op = mocker.Mock(side_effect=fake_exception) - stage.run_op(set_auth_provider) - assert_callback_failed(op=set_auth_provider, error=fake_exception) - - @pytest.mark.it( - "Allows any BaseExceptions raised by SetIoTHubConnectionArgsOperation to propagate" - ) - def test_set_auth_provider_raises_base_exception( - self, mocker, stage, fake_base_exception, set_auth_provider - ): - stage.next._execute_op = mocker.Mock(side_effect=fake_base_exception) - with pytest.raises(UnhandledException): - stage.run_op(set_auth_provider) - - @pytest.mark.it( - "Retrieves sas_token or x509_certificate on the auth provider and passes the result as the attribute of the next operation" - ) - def test_calls_get_current_sas_token_or_get_x509_certificate( - self, mocker, stage, set_auth_provider, params_auth_provider_ops - ): - - if params_auth_provider_ops["name"] == "sas_token_auth": - spy_method = mocker.spy(set_auth_provider.auth_provider, "get_current_sas_token") - elif "x509_auth" in params_auth_provider_ops["name"]: - spy_method = mocker.spy(set_auth_provider.auth_provider, "get_x509_certificate") - - stage.run_op(set_auth_provider) - assert spy_method.call_count == 1 - set_connection_args_op = stage.next._execute_op.call_args_list[0][0][0] - - if params_auth_provider_ops["name"] == "sas_token_auth": - assert set_connection_args_op.sas_token == fake_sas_token - elif "x509_auth" in params_auth_provider_ops["name"]: - assert set_connection_args_op.client_cert.certificate_file == fake_x509_cert_file - assert set_connection_args_op.client_cert.key_file == fake_x509_cert_key_file - assert set_connection_args_op.client_cert.pass_phrase == fake_pass_phrase - - @pytest.mark.it( - "Calls the callback with no error if the setting sas token or setting certificate operation succeeds" - ) - def test_returns_success_if_set_sas_token_or_set_client_certificate_succeeds( - self, stage, set_auth_provider - ): - stage.run_op(set_auth_provider) - assert_callback_succeeded(op=set_auth_provider) - - @pytest.mark.it( - "Handles any Exceptions raised by setting sas token or setting certificate and returns them through the op callback" - ) - def test_set_sas_token_or_set_client_certificate_raises_exception( - self, mocker, fake_exception, stage, set_auth_provider, params_auth_provider_ops - ): - if params_auth_provider_ops["name"] == "sas_token_auth": - set_auth_provider.auth_provider.get_current_sas_token = mocker.Mock( - side_effect=fake_exception - ) - elif "x509_auth" in params_auth_provider_ops["name"]: - set_auth_provider.auth_provider.get_x509_certificate = mocker.Mock( - side_effect=fake_exception - ) - - stage.run_op(set_auth_provider) - assert_callback_failed(op=set_auth_provider, error=fake_exception) - - @pytest.mark.it( - "Allows any BaseExceptions raised by get_current_sas_token or get_x509_certificate to propagate" - ) - def test_set_sas_token_or_set_client_certificate_raises_base_exception( - self, mocker, fake_base_exception, stage, set_auth_provider, params_auth_provider_ops - ): - if params_auth_provider_ops["name"] == "sas_token_auth": - set_auth_provider.auth_provider.get_current_sas_token = mocker.Mock( - side_effect=fake_base_exception - ) - elif "x509_auth" in params_auth_provider_ops["name"]: - set_auth_provider.auth_provider.get_x509_certificate = mocker.Mock( - side_effect=fake_base_exception - ) - with pytest.raises(UnhandledException): - stage.run_op(set_auth_provider) - - @pytest.mark.it("Sets the on_sas_token_updated_handler handler") - def test_sets_sas_token_updated_handler( - self, mocker, stage, set_auth_provider_all_args, params_auth_provider_ops - ): - if params_auth_provider_ops["name"] != "sas_token_auth": - pytest.mark.skip() - else: - stage.next._execute_op = mocker.Mock() - stage.run_op(set_auth_provider_all_args) - assert ( - set_auth_provider_all_args.auth_provider.on_sas_token_updated_handler - == stage.on_sas_token_updated - ) - - -@pytest.mark.describe("UseAuthProvider - .on_sas_token_updated()") -class TestUseAuthProviderOnSasTokenUpdated(object): - @pytest.fixture - def stage(self, mocker): - stage = make_mock_stage(mocker, pipeline_stages_iothub.UseAuthProviderStage) - auth_provider = mocker.MagicMock() - auth_provider.get_current_sas_token = mocker.MagicMock(return_value=fake_sas_token) - stage.auth_provider = auth_provider + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() return stage - @pytest.mark.it("Runs as a non-blocking function on the pipeline thread") - def test_runs_non_blocking(self, stage): - threading.current_thread().name = "not_pipeline" - return_value = stage.on_sas_token_updated() - assert isinstance(return_value, Future) - @pytest.mark.it( - "Runs a UpdateSasTokenOperation on the next stage with the sas token from self.auth_provider" - ) - def test_update_sas_token_operation(self, stage): - stage.on_sas_token_updated() - assert stage.next.run_op.call_count == 1 - assert isinstance( - stage.next.run_op.call_args[0][0], pipeline_ops_base.UpdateSasTokenOperation - ) - - @pytest.mark.it( - "Handles any Exceptions raised by the UpdateSasTokenOperation and passes them into the unhandled exception handler" - ) - def test_raises_exception(self, stage, mocker): - threading.current_thread().name = "not_pipeline" - - mocker.spy(unhandled_exceptions, "exception_caught_in_background_thread") - stage.next.run_op.side_effect = Exception - future = stage.on_sas_token_updated() - future.result() - - assert unhandled_exceptions.exception_caught_in_background_thread.call_count == 1 - assert isinstance( - unhandled_exceptions.exception_caught_in_background_thread.call_args[0][0], Exception - ) - - @pytest.mark.it("Allows any BaseExceptions raised by the UpdateSasTokenOperation to propagate") - def test_raises_base_exception(self, stage): - threading.current_thread().name = "not_pipeline" - - stage.next.run_op.side_effect = UnhandledException - future = stage.on_sas_token_updated() - - with pytest.raises(BaseException): - future.result() +class UseAuthProviderStageInstantiationTests(UseAuthProviderStageTestConfig): + @pytest.mark.it("Initializes 'auth_provider' as None") + def test_auth_provider(self, init_kwargs): + stage = pipeline_stages_iothub.UseAuthProviderStage(**init_kwargs) + assert stage.auth_provider is None pipeline_stage_test.add_base_pipeline_stage_tests( - cls=pipeline_stages_iothub.HandleTwinOperationsStage, - module=this_module, - all_ops=all_common_ops + all_iothub_ops, - handled_ops=[ - pipeline_ops_iothub.GetTwinOperation, - pipeline_ops_iothub.PatchTwinReportedPropertiesOperation, - ], - all_events=all_common_events + all_iothub_events, - handled_events=[], + test_module=this_module, + stage_class_under_test=pipeline_stages_iothub.UseAuthProviderStage, + stage_test_config_class=UseAuthProviderStageTestConfig, + extended_stage_instantiation_test_class=UseAuthProviderStageInstantiationTests, ) -@pytest.mark.describe("HandleTwinOperationsStage - .run_op() -- called with GetTwinOperation") -class TestHandleTwinOperationsRunOpWithGetTwin(object): - @pytest.fixture - def stage(self, mocker): - return make_mock_stage(mocker, pipeline_stages_iothub.HandleTwinOperationsStage) +@pytest.mark.describe( + "UseAuthProviderStage - .run_op() -- Called with SetAuthProviderOperation (SAS Authentication)" +) +class TestUseAuthProviderStageRunOpWithSetAuthProviderOperation( + StageRunOpTestBase, UseAuthProviderStageTestConfig +): + # Auth Providers are configured with different values depending on if the higher level client + # is a Device or Module. Parametrize with both possibilities. + # TODO: Eventually would be ideal to test using real auth provider instead of the fake one + # This probably should just wait until auth provider refactor for ease though. + @pytest.fixture(params=["Device", "Module"]) + def fake_auth_provider(self, request, mocker): + class FakeAuthProvider(AuthenticationProvider): + pass + + if request.param == "Device": + fake_auth_provider = FakeAuthProvider(hostname=fake_hostname, device_id=fake_device_id) + else: + fake_auth_provider = FakeAuthProvider( + hostname=fake_hostname, device_id=fake_device_id, module_id=fake_module_id + ) + fake_auth_provider.get_current_sas_token = mocker.MagicMock() + fake_auth_provider.on_sas_token_updated_handler_list = [mocker.MagicMock()] + return fake_auth_provider @pytest.fixture - def op(self, stage, callback): - return pipeline_ops_iothub.GetTwinOperation(callback=callback) - - @pytest.fixture - def twin(self): - return {"Am I a twin": "You bet I am"} - - @pytest.fixture - def twin_as_bytes(self, twin): - return json.dumps(twin).encode("utf-8") + def op(self, mocker, fake_auth_provider): + return pipeline_ops_iothub.SetAuthProviderOperation( + auth_provider=fake_auth_provider, callback=mocker.MagicMock() + ) @pytest.mark.it( - "Runs a SendIotRequestAndWaitForResponseOperation operation on the next stage with request_type='twin', method='GET', resource_location='/', and request_body=' '" + "Sets the operation's authentication provider on the stage as the 'auth_provider' attribute" ) - def test_sends_new_operation(self, stage, op): + def test_set_auth_provider(self, op, stage): + assert stage.auth_provider is None + stage.run_op(op) - assert stage.next.run_op.call_count == 1 - new_op = stage.next.run_op.call_args[0][0] - assert isinstance(new_op, pipeline_ops_base.SendIotRequestAndWaitForResponseOperation) + + assert stage.auth_provider is op.auth_provider + + # NOTE: Because currently auth providers don't have a consistent attribute surface, only some + # have the 'server_verification_cert' and 'gateway_hostname' attributes, so parametrize to show they default to + # None when non-existent. If authentication providers ever receive a uniform surface, this + # parametrization will no longer be required. + @pytest.mark.it( + "Sends a new SetIoTHubConnectionArgsOperation op down the pipeline, containing connection info from the authentication provider" + ) + @pytest.mark.parametrize( + "all_auth_args", [True, False], ids=["All authentication args", "Only guaranteed args"] + ) + def test_send_new_op_down(self, mocker, op, stage, all_auth_args): + if all_auth_args: + op.auth_provider.server_verification_cert = fake_server_verification_cert + op.auth_provider.gateway_hostname = fake_gateway_hostname + + stage.run_op(op) + + # A SetIoTHubConnectionArgsOperation op has been sent down the pipeline + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_iothub.SetIoTHubConnectionArgsOperation) + + # The IoTHubConnectionArgsOperation has details from the auth provider + assert new_op.device_id == op.auth_provider.device_id + assert new_op.module_id == op.auth_provider.module_id + assert new_op.hostname == op.auth_provider.hostname + assert new_op.sas_token is op.auth_provider.get_current_sas_token.return_value + assert new_op.client_cert is None + if all_auth_args: + assert new_op.server_verification_cert == op.auth_provider.server_verification_cert + assert new_op.gateway_hostname == op.auth_provider.gateway_hostname + else: + assert new_op.server_verification_cert is None + assert new_op.gateway_hostname is None + + @pytest.mark.it( + "Completes the original operation upon completion of the SetIoTHubConnectionArgsOperation" + ) + def test_complete_worker(self, op, stage, op_error): + # Run original op + stage.run_op(op) + assert not op.completed + + # A SetIoTHubConnectionArgsOperation op has been sent down the pipeline + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_iothub.SetIoTHubConnectionArgsOperation) + assert not new_op.completed + + # Complete the new op + new_op.complete(error=op_error) + + # Both ops are now completed + assert new_op.completed + assert new_op.error is op_error + assert op.completed + assert op.error is op_error + + +@pytest.mark.describe( + "UseAuthProviderStage - .run_op() -- Called with SetX509AuthProviderOperation (X509 Authentication)" +) +class TestUseAuthProviderStageRunOpWithSetX509AuthProviderOperation( + StageRunOpTestBase, UseAuthProviderStageTestConfig +): + # Auth Providers are configured with different values depending on if the higher level client + # is a Device or Module. Parametrize with both possibilities. + # TODO: Eventually would be ideal to test using real auth provider instead of the fake one + # This probably should just wait until auth provider refactor for ease though. + @pytest.fixture(params=["Device", "Module"]) + def fake_auth_provider(self, request, mocker): + class FakeAuthProvider(AuthenticationProvider): + pass + + if request.param == "Device": + fake_auth_provider = FakeAuthProvider(hostname=fake_hostname, device_id=fake_device_id) + else: + fake_auth_provider = FakeAuthProvider( + hostname=fake_hostname, device_id=fake_device_id, module_id=fake_module_id + ) + fake_auth_provider.get_x509_certificate = mocker.MagicMock() + return fake_auth_provider + + @pytest.fixture + def op(self, mocker, fake_auth_provider): + return pipeline_ops_iothub.SetX509AuthProviderOperation( + auth_provider=fake_auth_provider, callback=mocker.MagicMock() + ) + + @pytest.mark.it( + "Sets the operation's authentication provider on the stage as the 'auth_provider' attribute" + ) + def test_set_auth_provider(self, op, stage): + assert stage.auth_provider is None + + stage.run_op(op) + + assert stage.auth_provider is op.auth_provider + + # NOTE: Because currently auth providers don't have a consistent attribute surface, only some + # have the 'server_verification_cert' and 'gateway_hostname' attributes, so parametrize to show they default to + # None when non-existent. If authentication providers ever receive a uniform surface, this + # parametrization will no longer be required. + @pytest.mark.it( + "Sends a new SetIoTHubConnectionArgsOperation op down the pipeline, containing connection info from the authentication provider" + ) + @pytest.mark.parametrize( + "all_auth_args", [True, False], ids=["All authentication args", "Only guaranteed args"] + ) + def test_send_new_op_down(self, mocker, op, stage, all_auth_args): + if all_auth_args: + op.auth_provider.server_verification_cert = fake_server_verification_cert + op.auth_provider.gateway_hostname = fake_gateway_hostname + + stage.run_op(op) + + # A SetIoTHubConnectionArgsOperation op has been sent down the pipeline + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_iothub.SetIoTHubConnectionArgsOperation) + + # The IoTHubConnectionArgsOperation has details from the auth provider + assert new_op.device_id == op.auth_provider.device_id + assert new_op.module_id == op.auth_provider.module_id + assert new_op.hostname == op.auth_provider.hostname + assert new_op.client_cert is op.auth_provider.get_x509_certificate.return_value + assert new_op.sas_token is None + if all_auth_args: + assert new_op.server_verification_cert == op.auth_provider.server_verification_cert + assert new_op.gateway_hostname == op.auth_provider.gateway_hostname + else: + assert new_op.server_verification_cert is None + assert new_op.gateway_hostname is None + + @pytest.mark.it( + "Completes the original operation upon completion of the SetIoTHubConnectionArgsOperation" + ) + def test_complete_worker(self, op, stage, op_error): + # Run original op + stage.run_op(op) + assert not op.completed + + # A SetIoTHubConnectionArgsOperation op has been sent down the pipeline + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_iothub.SetIoTHubConnectionArgsOperation) + assert not new_op.completed + + # Complete the new op + new_op.complete(error=op_error) + + # Both ops are now completed + assert new_op.completed + assert new_op.error is op_error + assert op.completed + assert op.error is op_error + + +@pytest.mark.describe("UseAuthProviderStage - .run_op() -- Called with arbitrary other operation") +class TestUseAuthProviderStageRunOpWithAribitraryOperation( + StageRunOpTestBase, UseAuthProviderStageTestConfig +): + @pytest.fixture + def op(self, arbitrary_op): + return arbitrary_op + + @pytest.mark.it("Sends the operation down the pipeline") + def test_sends_down(self, mocker, stage, op): + stage.run_op(op) + + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) + assert not op.completed + + +@pytest.mark.describe( + "UseAuthProviderStage - OCCURANCE: SAS Authentication Provider updates SAS token" +) +class TestUseAuthProviderStageWhenAuthProviderGeneratesNewSasToken(UseAuthProviderStageTestConfig): + # Auth Providers are configured with different values depending on if the higher level client + # is a Device or Module. Parametrize with both possibilities. + # TODO: Eventually would be ideal to test using real auth provider instead of the fake one + # This probably should just wait until auth provider refactor for ease though. + @pytest.fixture(params=["Device", "Module"]) + def fake_auth_provider(self, request, mocker): + class FakeAuthProvider(AuthenticationProvider): + pass + + if request.param == "Device": + fake_auth_provider = FakeAuthProvider(hostname=fake_hostname, device_id=fake_device_id) + else: + fake_auth_provider = FakeAuthProvider( + hostname=fake_hostname, device_id=fake_device_id, module_id=fake_module_id + ) + fake_auth_provider.get_current_sas_token = mocker.MagicMock() + fake_auth_provider.on_sas_token_updated_handler_list = [mocker.MagicMock()] + return fake_auth_provider + + @pytest.fixture + def stage(self, mocker, init_kwargs, fake_auth_provider): + stage = pipeline_stages_iothub.UseAuthProviderStage(**init_kwargs) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + + # Attach an auth provider + set_auth_op = pipeline_ops_iothub.SetAuthProviderOperation( + auth_provider=fake_auth_provider, callback=mocker.MagicMock() + ) + stage.run_op(set_auth_op) + assert stage.auth_provider is fake_auth_provider + stage.send_op_down.reset_mock() + stage.send_event_up.reset_mock() + return stage + + @pytest.mark.it("Sends an UpdateSasTokenOperation with the new SAS token down the pipeline") + def test_generates_new_token(self, mocker, stage): + for x in stage.auth_provider.on_sas_token_updated_handler_list: + x() + + assert stage.send_op_down.call_count == 1 + op = stage.send_op_down.call_args[0][0] + assert isinstance(op, pipeline_ops_base.UpdateSasTokenOperation) + assert op.sas_token is stage.auth_provider.get_current_sas_token.return_value + + @pytest.mark.it( + "Sends the error to the background exception handler, if the UpdateSasTokenOperation is completed with error" + ) + def test_update_fails( + self, mocker, stage, arbitrary_exception, mock_handle_background_exception + ): + for x in stage.auth_provider.on_sas_token_updated_handler_list: + x() + + assert stage.send_op_down.call_count == 1 + op = stage.send_op_down.call_args[0][0] + + assert mock_handle_background_exception.call_count == 0 + + op.complete(error=arbitrary_exception) + assert mock_handle_background_exception.call_count == 1 + assert mock_handle_background_exception.call_args == mocker.call(arbitrary_exception) + + +############################### +# TWIN REQUEST RESPONSE STAGE # +############################### + + +class TwinRequestResponseStageTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_stages_iothub.TwinRequestResponseStage + + @pytest.fixture + def init_kwargs(self): + return {} + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + return stage + + +pipeline_stage_test.add_base_pipeline_stage_tests( + test_module=this_module, + stage_class_under_test=pipeline_stages_iothub.TwinRequestResponseStage, + stage_test_config_class=TwinRequestResponseStageTestConfig, +) + + +@pytest.mark.describe("TwinRequestResponseStage - .run_op() -- Called with GetTwinOperation") +class TestTwinRequestResponseStageRunOpWithGetTwinOperation( + StageRunOpTestBase, TwinRequestResponseStageTestConfig +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_iothub.GetTwinOperation(callback=mocker.MagicMock()) + + @pytest.mark.it( + "Sends a new RequestAndResponseOperation down the pipeline, configured to request a twin" + ) + def test_request_and_response_op(self, mocker, stage, op): + stage.run_op(op) + + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_base.RequestAndResponseOperation) assert new_op.request_type == "twin" assert new_op.method == "GET" assert new_op.resource_location == "/" assert new_op.request_body == " " - @pytest.mark.it("Returns an Exception through the op callback if there is no next stage") - def test_runs_with_no_next_stage(self, stage, op): - stage.next = None - stage.run_op(op) - assert_callback_failed(op=op, error=Exception) - - @pytest.mark.it( - "Handles any Exceptions raised by the SendIotRequestAndWaitForResponseOperation and returns them through the op callback" - ) - def test_next_stage_raises_exception(self, stage, op, mocker): - stage.next.run_op.side_effect = Exception - stage.run_op(op) - assert_callback_failed(op=op, error=Exception) - - @pytest.mark.it( - "Allows any BaseExceptions raised by the SendIotRequestAndWaitForResponseOperation to propagate" - ) - def test_next_stage_raises_base_exception(self, stage, op): - stage.next.run_op.side_effect = UnhandledException - with pytest.raises(UnhandledException): - stage.run_op(op) - - @pytest.mark.it( - "Returns any error in the SendIotRequestAndWaitForResponseOperation callback through the op callback" - ) - def test_next_stage_returns_error(self, stage, op): - error = Exception() - - def next_stage_run_op(self, op): - op.error = error - op.callback(op) - - stage.next.run_op = functools.partial(next_stage_run_op, (stage.next,)) - stage.run_op(op) - assert_callback_failed(op=op, error=error) - - @pytest.mark.it( - "Returns an error in the op callback if the SendIotRequestAndWaitForResponseOperation returns a status code >= 300" - ) - def test_next_stage_returns_status_over_300(self, stage, op): - def next_stage_run_op(self, op): - op.status_code = 400 - op.callback(op) - - stage.next.run_op = functools.partial(next_stage_run_op, (stage.next,)) - stage.run_op(op) - assert_callback_failed(op=op, error=Exception) - - @pytest.mark.it( - "Decodes, deserializes, and returns the request_body from SendIotRequestAndWaitForResponseOperation as the twin attribute on the op along with no error if the status code < 300" - ) - def test_next_stage_completes_correctly(self, stage, op, twin, twin_as_bytes): - def next_stage_run_op(self, op): - op.status_code = 200 - op.response_body = twin_as_bytes - op.callback(op) - - stage.next.run_op = functools.partial(next_stage_run_op, (stage.next,)) - stage.run_op(op) - assert_callback_succeeded(op=op) - assert op.twin == twin - @pytest.mark.describe( - "HandleTwinOperationsStage - .run_op() -- called with PatchTwinReportedPropertiesOperation" + "TwinRequestResponseStage - .run_op() -- Called with PatchTwinReportedPropertiesOperation" ) -class TestHandleTwinOperationsRunOpWithPatchTwinReportedProperties(object): +class TestTwinRequestResponseStageRunOpWithPatchTwinReportedPropertiesOperation( + StageRunOpTestBase, TwinRequestResponseStageTestConfig +): + # CT-TODO: parametrize this with realistic json objects @pytest.fixture - def stage(self, mocker): - return make_mock_stage(mocker, pipeline_stages_iothub.HandleTwinOperationsStage) + def json_patch(self): + return {"json_key": "json_val"} @pytest.fixture - def patch(self): - return {"__fake_patch__": "yes"} - - @pytest.fixture - def patch_as_string(self, patch): - return json.dumps(patch) - - @pytest.fixture - def op(self, stage, callback, patch): + def op(self, mocker, json_patch): return pipeline_ops_iothub.PatchTwinReportedPropertiesOperation( - patch=patch, callback=callback + patch=json_patch, callback=mocker.MagicMock() ) @pytest.mark.it( - "Runs a SendIotRequestAndWaitForResponseOperation operation on the next stage with request_type='twin', method='PATCH', resource_location='/properties/reported/', and the request_body attribute set to a stringification of the patch" + "Sends a new RequestAndResponseOperation down the pipeline, configured to send a twin reported properties patch, with the patch serialized as a JSON string" ) - def test_sends_new_operation(self, stage, op, patch_as_string): + def test_request_and_response_op(self, mocker, stage, op): stage.run_op(op) - assert stage.next.run_op.call_count == 1 - new_op = stage.next.run_op.call_args[0][0] - assert isinstance(new_op, pipeline_ops_base.SendIotRequestAndWaitForResponseOperation) + + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_base.RequestAndResponseOperation) assert new_op.request_type == "twin" assert new_op.method == "PATCH" assert new_op.resource_location == "/properties/reported/" - assert new_op.request_body == patch_as_string + assert new_op.request_body == json.dumps(op.patch) - @pytest.mark.it("Returns an Exception through the op callback if there is no next stage") - def test_runs_with_no_next_stage(self, stage, op): - stage.next = None + +@pytest.mark.describe( + "TwinRequestResponseStage - .run_op() -- Called with other arbitrary operation" +) +class TestTwinRequestResponseStageRunOpWithArbitraryOperation( + StageRunOpTestBase, TwinRequestResponseStageTestConfig +): + @pytest.fixture + def op(self, arbitrary_op): + return arbitrary_op + + @pytest.mark.it("Sends the operation down the pipeline") + def test_sends_op_down(self, mocker, stage, op): stage.run_op(op) - assert_callback_failed(op=op, error=Exception) + + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) + + +# TODO: Provide a more accurate set of status codes for tests +@pytest.mark.describe( + "TwinRequestResponseStage - OCCURANCE: RequestAndResponseOperation created from GetTwinOperation is completed" +) +class TestTwinRequestResponseStageWhenRequestAndResponseCreatedFromGetTwinOperationCompleted( + TwinRequestResponseStageTestConfig +): + @pytest.fixture + def get_twin_op(self, mocker): + return pipeline_ops_iothub.GetTwinOperation(callback=mocker.MagicMock()) + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs, get_twin_op): + stage = cls_type(**init_kwargs) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + + # Run the GetTwinOperation + stage.run_op(get_twin_op) + + return stage + + @pytest.fixture + def request_and_response_op(self, stage): + assert stage.send_op_down.call_count == 1 + op = stage.send_op_down.call_args[0][0] + assert isinstance(op, pipeline_ops_base.RequestAndResponseOperation) + + # reset the stage mock for convenience + stage.send_op_down.reset_mock() + + return op @pytest.mark.it( - "Handles any Exceptions raised by the SendIotRequestAndWaitForResponseOperation and returns them through the op callback" + "Completes the GetTwinOperation unsuccessfully, with the error from the RequestAndResponseOperation, if the RequestAndResponseOperation is completed unsuccessfully" ) - def test_next_stage_raises_exception(self, stage, op): - stage.next.run_op.side_effect = Exception - stage.run_op(op) - assert_callback_failed(op=op, error=Exception) + @pytest.mark.parametrize( + "has_response_body", [True, False], ids=["With Response Body", "No Response Body"] + ) + @pytest.mark.parametrize( + "status_code", + [ + pytest.param(None, id="Status Code: None"), + pytest.param(200, id="Status Code: 200"), + pytest.param(300, id="Status Code: 300"), + pytest.param(400, id="Status Code: 400"), + pytest.param(500, id="Status Code: 500"), + ], + ) + def test_request_and_response_op_completed_with_err( + self, + stage, + get_twin_op, + request_and_response_op, + arbitrary_exception, + status_code, + has_response_body, + ): + assert not get_twin_op.completed + assert not request_and_response_op.completed + + # NOTE: It shouldn't happen that an operation completed with error has a status code or a + # response body, but it IS possible. + request_and_response_op.status_code = status_code + if has_response_body: + request_and_response_op.response_body = b'{"key": "value"}' + request_and_response_op.complete(error=arbitrary_exception) + + assert request_and_response_op.completed + assert request_and_response_op.error is arbitrary_exception + assert get_twin_op.completed + assert get_twin_op.error is arbitrary_exception + # Twin is NOT returned + assert get_twin_op.twin is None @pytest.mark.it( - "Allows any BaseExceptions raised by the SendIotRequestAndWaitForResponseOperation to propagate" + "Completes the GetTwinOperation unsuccessfully with a ServiceError if the RequestAndResponseOperation is completed successfully with a status code indicating an unsuccessful result from the service" ) - def test_next_stage_raises_base_exception(self, stage, op): - stage.next.run_op.side_effect = UnhandledException - with pytest.raises(UnhandledException): - stage.run_op(op) + @pytest.mark.parametrize( + "has_response_body", [True, False], ids=["With Response Body", "No Response Body"] + ) + @pytest.mark.parametrize( + "status_code", + [ + pytest.param(300, id="Status Code: 300"), + pytest.param(400, id="Status Code: 400"), + pytest.param(500, id="Status Code: 500"), + ], + ) + def test_request_and_response_op_completed_success_with_bad_code( + self, stage, get_twin_op, request_and_response_op, status_code, has_response_body + ): + assert not get_twin_op.completed + assert not request_and_response_op.completed + + request_and_response_op.status_code = status_code + if has_response_body: + request_and_response_op.response_body = b'{"key": "value"}' + request_and_response_op.complete() + + assert request_and_response_op.completed + assert request_and_response_op.error is None + assert get_twin_op.completed + assert isinstance(get_twin_op.error, ServiceError) + # Twin is NOT returned + assert get_twin_op.twin is None @pytest.mark.it( - "Returns any error in the SendIotRequestAndWaitForResponseOperation callback through the op callback" + "Completes the GetTwinOperation successfully (with the JSON deserialized response body from the RequestAndResponseOperation as the twin) if the RequestAndResponseOperation is completed successfully with a status code indicating a successful result from the service" ) - def test_next_stage_returns_error(self, stage, op): - error = Exception() + @pytest.mark.parametrize( + "response_body, expected_twin", + [ + pytest.param(b'{"key": "value"}', {"key": "value"}, id="Twin 1"), + pytest.param(b'{"key1": {"key2": "value"}}', {"key1": {"key2": "value"}}, id="Twin 2"), + pytest.param( + b'{"key1": {"key2": {"key3": "value1", "key4": "value2"}, "key5": "value3"}, "key6": {"key7": "value4"}, "key8": "value5"}', + { + "key1": {"key2": {"key3": "value1", "key4": "value2"}, "key5": "value3"}, + "key6": {"key7": "value4"}, + "key8": "value5", + }, + id="Twin 3", + ), + ], + ) + def test_request_and_response_op_completed_success_with_good_code( + self, stage, get_twin_op, request_and_response_op, response_body, expected_twin + ): + assert not get_twin_op.completed + assert not request_and_response_op.completed - def next_stage_run_op(self, op): - op.error = error - op.callback(op) + request_and_response_op.status_code = 200 + request_and_response_op.response_body = response_body + request_and_response_op.complete() - stage.next.run_op = functools.partial(next_stage_run_op, (stage.next,)) - stage.run_op(op) - assert_callback_failed(op=op, error=error) + assert request_and_response_op.completed + assert request_and_response_op.error is None + assert get_twin_op.completed + assert get_twin_op.error is None + assert get_twin_op.twin == expected_twin + + +@pytest.mark.describe( + "TwinRequestResponseStage - OCCURANCE: RequestAndResponseOperation created from PatchTwinReportedPropertiesOperation is completed" +) +class TestTwinRequestResponseStageWhenRequestAndResponseCreatedFromPatchTwinReportedPropertiesOperation( + TwinRequestResponseStageTestConfig +): + @pytest.fixture + def patch_twin_reported_properties_op(self, mocker): + return pipeline_ops_iothub.PatchTwinReportedPropertiesOperation( + patch={"json_key": "json_val"}, callback=mocker.MagicMock() + ) + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs, patch_twin_reported_properties_op): + stage = cls_type(**init_kwargs) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + + # Run the GetTwinOperation + stage.run_op(patch_twin_reported_properties_op) + + return stage + + @pytest.fixture + def request_and_response_op(self, stage): + assert stage.send_op_down.call_count == 1 + op = stage.send_op_down.call_args[0][0] + assert isinstance(op, pipeline_ops_base.RequestAndResponseOperation) + + # reset the stage mock for convenience + stage.send_op_down.reset_mock() + + return op @pytest.mark.it( - "Returns an error in the op callback if the SendIotRequestAndWaitForResponseOperation returns a status code >= 300" + "Completes the PatchTwinReportedPropertiesOperation unsuccessfully, with the error from the RequestAndResponseOperation, if the RequestAndResponseOperation is completed unsuccessfully" ) - def test_next_stage_returns_status_over_300(self, stage, op): - def next_stage_run_op(self, op): - op.status_code = 400 - op.callback(op) + @pytest.mark.parametrize( + "status_code", + [ + pytest.param(None, id="Status Code: None"), + pytest.param(200, id="Status Code: 200"), + pytest.param(300, id="Status Code: 300"), + pytest.param(400, id="Status Code: 400"), + pytest.param(500, id="Status Code: 500"), + ], + ) + def test_request_and_response_op_completed_with_err( + self, + stage, + patch_twin_reported_properties_op, + request_and_response_op, + arbitrary_exception, + status_code, + ): + assert not patch_twin_reported_properties_op.completed + assert not request_and_response_op.completed - stage.next.run_op = functools.partial(next_stage_run_op, (stage.next,)) - stage.run_op(op) - assert_callback_failed(op=op, error=Exception) + # NOTE: It shouldn't happen that an operation completed with error has a status code + # but it IS possible + request_and_response_op.status_code = status_code + request_and_response_op.complete(error=arbitrary_exception) - @pytest.mark.it("Returns no error on the op callback if the status code < 300") - def test_next_stage_completes_correctly(self, stage, op): - def next_stage_run_op(self, op): - op.status_code = 200 - op.callback(op) + assert request_and_response_op.completed + assert request_and_response_op.error is arbitrary_exception + assert patch_twin_reported_properties_op.completed + assert patch_twin_reported_properties_op.error is arbitrary_exception - stage.next.run_op = functools.partial(next_stage_run_op, (stage.next,)) - stage.run_op(op) - assert_callback_succeeded(op=op) + @pytest.mark.it( + "Completes the PatchTwinReportedPropertiesOperation unsuccessfully with a ServiceError if the RequestAndResponseOperation is completed successfully with a status code indicating an unsuccessful result from the service" + ) + @pytest.mark.parametrize( + "status_code", + [ + pytest.param(300, id="Status Code: 300"), + pytest.param(400, id="Status Code: 400"), + pytest.param(500, id="Status Code: 500"), + ], + ) + def test_request_and_response_op_completed_success_with_bad_code( + self, stage, patch_twin_reported_properties_op, request_and_response_op, status_code + ): + assert not patch_twin_reported_properties_op.completed + assert not request_and_response_op.completed + + request_and_response_op.status_code = status_code + request_and_response_op.complete() + + assert request_and_response_op.completed + assert request_and_response_op.error is None + assert patch_twin_reported_properties_op.completed + assert isinstance(patch_twin_reported_properties_op.error, ServiceError) + + @pytest.mark.it( + "Completes the PatchTwinReportedPropertiesOperation successfully if the RequestAndResponseOperation is completed successfully with a status code indicating a successful result from the service" + ) + def test_request_and_response_op_completed_success_with_good_code( + self, stage, patch_twin_reported_properties_op, request_and_response_op + ): + assert not patch_twin_reported_properties_op.completed + assert not request_and_response_op.completed + + request_and_response_op.status_code = 200 + request_and_response_op.complete() + + assert request_and_response_op.completed + assert request_and_response_op.error is None + assert patch_twin_reported_properties_op.completed + assert patch_twin_reported_properties_op.error is None diff --git a/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub_http.py b/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub_http.py new file mode 100644 index 000000000..5f8b9e41e --- /dev/null +++ b/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub_http.py @@ -0,0 +1,880 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import logging +import pytest +import json +import sys +import six.moves.urllib as urllib +from azure.iot.device.common.pipeline import pipeline_stages_base, pipeline_ops_http +from azure.iot.device.iothub.pipeline import ( + pipeline_ops_iothub, + pipeline_ops_iothub_http, + pipeline_stages_iothub_http, + config, +) +from azure.iot.device.exceptions import ServiceError +from tests.common.pipeline.helpers import StageRunOpTestBase +from tests.common.pipeline import pipeline_stage_test +from azure.iot.device import constant as pkg_constant +from azure.iot.device.product_info import ProductInfo + +logging.basicConfig(level=logging.DEBUG) +pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") +this_module = sys.modules[__name__] + +################### +# COMMON FIXTURES # +################### + + +@pytest.fixture(params=[True, False], ids=["With error", "No error"]) +def op_error(request, arbitrary_exception): + if request.param: + return arbitrary_exception + else: + return None + + +@pytest.fixture +def mock_http_path_iothub(mocker): + mock = mocker.patch( + "azure.iot.device.iothub.pipeline.pipeline_stages_iothub_http.http_path_iothub" + ) + return mock + + +################################## +# IOT HUB HTTP TRANSLATION STAGE # +################################## + + +class IoTHubHTTPTranslationStageTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_stages_iothub_http.IoTHubHTTPTranslationStage + + @pytest.fixture + def init_kwargs(self): + return {} + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + return stage + + +class IoTHubHTTPTranslationStageInstantiationTests(IoTHubHTTPTranslationStageTestConfig): + @pytest.mark.it("Initializes 'device_id' as None") + def test_device_id(self, init_kwargs): + stage = pipeline_stages_iothub_http.IoTHubHTTPTranslationStage(**init_kwargs) + assert stage.device_id is None + + @pytest.mark.it("Initializes 'module_id' as None") + def test_module_id(self, init_kwargs): + stage = pipeline_stages_iothub_http.IoTHubHTTPTranslationStage(**init_kwargs) + assert stage.module_id is None + + @pytest.mark.it("Initializes 'hostname' as None") + def test_hostname(self, init_kwargs): + stage = pipeline_stages_iothub_http.IoTHubHTTPTranslationStage(**init_kwargs) + assert stage.hostname is None + + +pipeline_stage_test.add_base_pipeline_stage_tests( + test_module=this_module, + stage_class_under_test=pipeline_stages_iothub_http.IoTHubHTTPTranslationStage, + stage_test_config_class=IoTHubHTTPTranslationStageTestConfig, + extended_stage_instantiation_test_class=IoTHubHTTPTranslationStageInstantiationTests, +) + + +@pytest.mark.describe( + "IoTHubHTTPTranslationStage - .run_op() -- Called with SetIoTHubConnectionArgsOperation op" +) +class TestIoTHubHTTPTranslationStageRunOpCalledWithConnectionArgsOperation( + IoTHubHTTPTranslationStageTestConfig, StageRunOpTestBase +): + @pytest.fixture(params=["SAS", "X509"]) + def auth_type(self, request): + return request.param + + @pytest.fixture(params=[True, False], ids=["w/ GatewayHostName", "No GatewayHostName"]) + def use_gateway_hostname(self, request): + return request.param + + @pytest.fixture( + params=[True, False], ids=["w/ server verification cert", "No server verification cert"] + ) + def use_server_verification_cert(self, request): + return request.param + + @pytest.fixture(params=["Device", "Module"]) + def op(self, mocker, request, auth_type, use_gateway_hostname, use_server_verification_cert): + kwargs = { + "device_id": "fake_device_id", + "hostname": "fake_hostname", + "callback": mocker.MagicMock(), + } + if request.param == "Module": + kwargs["module_id"] = "fake_module_id" + + if auth_type == "SAS": + kwargs["sas_token"] = "fake_sas_token" + else: + kwargs["client_cert"] = mocker.MagicMock() # representing X509 obj + + if use_gateway_hostname: + kwargs["gateway_hostname"] = "fake_gateway_hostname" + + if use_server_verification_cert: + kwargs["server_verification_cert"] = "fake_server_verification_cert" + + return pipeline_ops_iothub.SetIoTHubConnectionArgsOperation(**kwargs) + + @pytest.mark.it( + "Sets the 'device_id' and 'module_id' values from the op as the stage's 'device_id' and 'module_id' attributes" + ) + def test_cache_device_id_and_module_id(self, stage, op): + assert stage.device_id is None + assert stage.module_id is None + + stage.run_op(op) + + assert stage.device_id == op.device_id + assert stage.module_id == op.module_id + + @pytest.mark.it( + "Sets the 'gateway_hostname' value from the op as the stage's 'hostname' attribute if one is provided, otherwise, use the op's 'hostname'" + ) + def test_cache_hostname(self, stage, op): + assert stage.hostname is None + stage.run_op(op) + + if op.gateway_hostname is not None: + assert stage.hostname == op.gateway_hostname + assert stage.hostname != op.hostname + else: + assert stage.hostname == op.hostname + assert stage.hostname != op.gateway_hostname + + @pytest.mark.it( + "Sends a new SetHTTPConnectionArgsOperation op down the pipeline, configured based on the settings of the SetIoTHubConnectionArgsOperation" + ) + def test_sends_op_down(self, mocker, stage, op): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.SetHTTPConnectionArgsOperation) + + # Validate contents of the op + assert new_op.hostname == stage.hostname + assert new_op.server_verification_cert == op.server_verification_cert + assert new_op.client_cert == op.client_cert + assert new_op.sas_token == op.sas_token + + @pytest.mark.it( + "Completes the original SetIoTHubConnectionArgsOperation (with the same error, or lack thereof) if the new SetHTTPConnectionArgsOperation is completed later on" + ) + def test_completing_new_op_completes_original(self, mocker, stage, op_error, op): + stage.run_op(op) + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + + assert not op.completed + assert not new_op.completed + + new_op.complete(error=op_error) + + assert new_op.completed + assert new_op.error is op_error + assert op.completed + assert op.error is op_error + + +@pytest.mark.describe( + "IoTHubHTTPTranslationStage - .run_op() -- Called with MethodInvokeOperation op" +) +class TestIoTHubHTTPTranslationStageRunOpCalledWithMethodInvokeOperation( + IoTHubHTTPTranslationStageTestConfig, StageRunOpTestBase +): + # Because Storage/Blob related functionality is limited to Module, configure the stage for a module + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + pl_config = config.IoTHubPipelineConfig() + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=pl_config + ) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + stage.device_id = "fake_device_id" + stage.module_id = "fake_module_id" + stage.hostname = "fake_hostname" + return stage + + @pytest.fixture(params=["Targeting Device Method", "Targeting Module Method"]) + def op(self, mocker, request): + method_params = {"arg1": "val", "arg2": 2, "arg3": True} + if request.param == "Targeting Device Method": + return pipeline_ops_iothub_http.MethodInvokeOperation( + target_device_id="fake_target_device_id", + target_module_id=None, + method_params=method_params, + callback=mocker.MagicMock(), + ) + else: + return pipeline_ops_iothub_http.MethodInvokeOperation( + target_device_id="fake_target_device_id", + target_module_id="fake_target_module_id", + method_params=method_params, + callback=mocker.MagicMock(), + ) + + @pytest.mark.it("Sends a new HTTPRequestAndResponseOperation op down the pipeline") + def test_sends_op_down(self, mocker, stage, op): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + @pytest.mark.it( + "Configures the HTTPRequestAndResponseOperation with request details for sending a Method Invoke request" + ) + def test_sends_get_storage_request(self, mocker, stage, op, mock_http_path_iothub): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + # Validate request + assert mock_http_path_iothub.get_method_invoke_path.call_count == 1 + assert mock_http_path_iothub.get_method_invoke_path.call_args == mocker.call( + op.target_device_id, op.target_module_id + ) + expected_path = mock_http_path_iothub.get_method_invoke_path.return_value + + assert new_op.method == "POST" + assert new_op.path == expected_path + assert new_op.query_params == "api-version={}".format(pkg_constant.IOTHUB_API_VERSION) + + @pytest.mark.it( + "Configures the HTTPRequestAndResponseOperation with the headers for a Method Invoke request" + ) + @pytest.mark.parametrize( + "custom_user_agent", + [ + pytest.param("", id="No custom user agent"), + pytest.param("MyCustomUserAgent", id="With custom user agent"), + pytest.param( + "My/Custom?User+Agent", id="With custom user agent containing reserved characters" + ), + pytest.param(12345, id="Non-string custom user agent"), + ], + ) + def test_new_op_headers(self, mocker, stage, op, custom_user_agent): + stage.pipeline_root.pipeline_configuration.product_info = custom_user_agent + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + # Validate headers + expected_user_agent = urllib.parse.quote_plus( + ProductInfo.get_iothub_user_agent() + str(custom_user_agent) + ) + expected_edge_string = "{}/{}".format(stage.device_id, stage.module_id) + + assert new_op.headers["Host"] == stage.hostname + assert new_op.headers["Content-Type"] == "application/json" + assert new_op.headers["Content-Length"] == len(new_op.body) + assert new_op.headers["x-ms-edge-moduleId"] == expected_edge_string + assert new_op.headers["User-Agent"] == expected_user_agent + + @pytest.mark.it( + "Configures the HTTPRequestAndResponseOperation with a body for a Method Invoke request" + ) + def test_new_op_body(self, mocker, stage, op): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + # Validate body + assert new_op.body == json.dumps(op.method_params) + + @pytest.mark.it( + "Completes the original MethodInvokeOperation op (no error) if the new HTTPRequestAndResponseOperation op is completed later on (no error) with a status code indicating success" + ) + def test_new_op_completes_with_good_code(self, mocker, stage, op): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + # Neither op is completed + assert not op.completed + assert op.error is None + assert not new_op.completed + assert new_op.error is None + + # Complete new op + new_op.response_body = b'{"some_response_key": "some_response_value"}' + new_op.status_code = 200 + new_op.complete() + + # Both ops are now completed successfully + assert new_op.completed + assert new_op.error is None + assert op.completed + assert op.error is None + + @pytest.mark.it( + "Deserializes the completed HTTPRequestAndResponseOperation op's 'response_body' (the received storage info) and set it on the MethodInvokeOperation op as the 'method_response', if the HTTPRequestAndResponseOperation is completed later (no error) with a status code indicating success" + ) + @pytest.mark.parametrize( + "response_body, expected_method_response", + [ + pytest.param( + b'{"key": "val"}', {"key": "val"}, id="Response Body: dict value as bytestring" + ), + pytest.param( + b'{"key": "val", "key2": {"key3": "val2"}}', + {"key": "val", "key2": {"key3": "val2"}}, + id="Response Body: dict value as bytestring", + ), + ], + ) + def test_deserializes_response( + self, mocker, stage, op, response_body, expected_method_response + ): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + # Original op has no 'method_response' + assert op.method_response is None + + # Complete new op + new_op.response_body = response_body + new_op.status_code = 200 + new_op.complete() + + # Method Response is set + assert op.method_response == expected_method_response + + @pytest.mark.it( + "Completes the original MethodInvokeOperation op with a ServiceError if the new HTTPRequestAndResponseOperation is completed later on (no error) with a status code indicating non-success" + ) + @pytest.mark.parametrize( + "status_code", + [ + pytest.param(300, id="Status Code: 300"), + pytest.param(400, id="Status Code: 400"), + pytest.param(500, id="Status Code: 500"), + ], + ) + def test_new_op_completes_with_bad_code(self, mocker, stage, op, status_code): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + # Neither op is completed + assert not op.completed + assert op.error is None + assert not new_op.completed + assert new_op.error is None + + # Complete new op successfully (but with a bad status code) + new_op.status_code = status_code + new_op.complete() + + # The original op is now completed with a ServiceError + assert new_op.completed + assert new_op.error is None + assert op.completed + assert isinstance(op.error, ServiceError) + + @pytest.mark.it( + "Completes the original MethodInvokeOperation op with the error from the new HTTPRequestAndResponseOperation, if the HTTPRequestAndResponseOperation is completed later on with error" + ) + def test_new_op_completes_with_error(self, mocker, stage, op, arbitrary_exception): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + # Neither op is completed + assert not op.completed + assert op.error is None + assert not new_op.completed + assert new_op.error is None + + # Complete new op with error + new_op.complete(error=arbitrary_exception) + + # The original op is now completed with a ServiceError + assert new_op.completed + assert new_op.error is arbitrary_exception + assert op.completed + assert op.error is arbitrary_exception + + +@pytest.mark.describe( + "IoTHubHTTPTranslationStage - .run_op() -- Called with GetStorageInfoOperation op" +) +class TestIoTHubHTTPTranslationStageRunOpCalledWithGetStorageInfoOperation( + IoTHubHTTPTranslationStageTestConfig, StageRunOpTestBase +): + + # Because Storage/Blob related functionality is limited to Devices, configure the stage for a device + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + pl_config = config.IoTHubPipelineConfig() + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=pl_config + ) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + stage.device_id = "fake_device_id" + stage.module_id = None + stage.hostname = "fake_hostname" + return stage + + @pytest.fixture + def op(self, mocker): + return pipeline_ops_iothub_http.GetStorageInfoOperation( + blob_name="fake_blob_name", callback=mocker.MagicMock() + ) + + @pytest.mark.it("Sends a new HTTPRequestAndResponseOperation op down the pipeline") + def test_sends_op_down(self, mocker, stage, op): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + @pytest.mark.it( + "Configures the HTTPRequestAndResponseOperation with request details for sending a Get Storage Info request" + ) + def test_sends_get_storage_request(self, mocker, stage, op, mock_http_path_iothub): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + # Validate request + assert mock_http_path_iothub.get_storage_info_for_blob_path.call_count == 1 + assert mock_http_path_iothub.get_storage_info_for_blob_path.call_args == mocker.call( + stage.device_id + ) + expected_path = mock_http_path_iothub.get_storage_info_for_blob_path.return_value + + assert new_op.method == "POST" + assert new_op.path == expected_path + assert new_op.query_params == "api-version={}".format(pkg_constant.IOTHUB_API_VERSION) + + @pytest.mark.it( + "Configures the HTTPRequestAndResponseOperation with the headers for a Get Storage Info request" + ) + @pytest.mark.parametrize( + "custom_user_agent", + [ + pytest.param("", id="No custom user agent"), + pytest.param("MyCustomUserAgent", id="With custom user agent"), + pytest.param( + "My/Custom?User+Agent", id="With custom user agent containing reserved characters" + ), + pytest.param(12345, id="Non-string custom user agent"), + ], + ) + def test_new_op_headers(self, mocker, stage, op, custom_user_agent): + stage.pipeline_root.pipeline_configuration.product_info = custom_user_agent + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + # Validate headers + expected_user_agent = urllib.parse.quote_plus( + ProductInfo.get_iothub_user_agent() + str(custom_user_agent) + ) + + assert new_op.headers["Host"] == stage.hostname + assert new_op.headers["Accept"] == "application/json" + assert new_op.headers["Content-Type"] == "application/json" + assert new_op.headers["Content-Length"] == len(new_op.body) + assert new_op.headers["User-Agent"] == expected_user_agent + + @pytest.mark.it( + "Configures the HTTPRequestAndResponseOperation with a body for a Get Storage Info request" + ) + def test_new_op_body(self, mocker, stage, op): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + # Validate body + assert new_op.body == '{{"blobName": "{}"}}'.format(op.blob_name) + + @pytest.mark.it( + "Completes the original GetStorageInfoOperation op (no error) if the new HTTPRequestAndResponseOperation is completed later on (no error) with a status code indicating success" + ) + def test_new_op_completes_with_good_code(self, mocker, stage, op): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + # Neither op is completed + assert not op.completed + assert op.error is None + assert not new_op.completed + assert new_op.error is None + + # Complete new op + new_op.response_body = b'{"json": "response"}' + new_op.status_code = 200 + new_op.complete() + + # Both ops are now completed successfully + assert new_op.completed + assert new_op.error is None + assert op.completed + assert op.error is None + + @pytest.mark.it( + "Deserializes the completed HTTPRequestAndResponseOperation op's 'response_body' (the received storage info) and set it on the GetStorageInfoOperation as the 'storage_info', if the HTTPRequestAndResponseOperation is completed later (no error) with a status code indicating success" + ) + def test_deserializes_response(self, mocker, stage, op): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + # Original op has no 'storage_info' + assert op.storage_info is None + + # Complete new op + new_op.response_body = b'{\ + "hostName": "fake_hostname",\ + "containerName": "fake_container_name",\ + "blobName": "fake_blob_name",\ + "sasToken": "fake_sas_token",\ + "correlationId": "fake_correlation_id"\ + }' + new_op.status_code = 200 + new_op.complete() + + # Storage Info is set + assert op.storage_info == { + "hostName": "fake_hostname", + "containerName": "fake_container_name", + "blobName": "fake_blob_name", + "sasToken": "fake_sas_token", + "correlationId": "fake_correlation_id", + } + + @pytest.mark.it( + "Completes the original GetStorageInfoOperation op with a ServiceError if the new HTTPRequestAndResponseOperation is completed later on (no error) with a status code indicating non-success" + ) + @pytest.mark.parametrize( + "status_code", + [ + pytest.param(300, id="Status Code: 300"), + pytest.param(400, id="Status Code: 400"), + pytest.param(500, id="Status Code: 500"), + ], + ) + def test_new_op_completes_with_bad_code(self, mocker, stage, op, status_code): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + # Neither op is completed + assert not op.completed + assert op.error is None + assert not new_op.completed + assert new_op.error is None + + # Complete new op successfully (but with a bad status code) + new_op.status_code = status_code + new_op.complete() + + # The original op is now completed with a ServiceError + assert new_op.completed + assert new_op.error is None + assert op.completed + assert isinstance(op.error, ServiceError) + + @pytest.mark.it( + "Completes the original GetStorageInfoOperation op with the error from the new HTTPRequestAndResponseOperation, if the HTTPRequestAndResponseOperation is completed later on with error" + ) + def test_new_op_completes_with_error(self, mocker, stage, op, arbitrary_exception): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + # Neither op is completed + assert not op.completed + assert op.error is None + assert not new_op.completed + assert new_op.error is None + + # Complete new op with error + new_op.complete(error=arbitrary_exception) + + # The original op is now completed with a ServiceError + assert new_op.completed + assert new_op.error is arbitrary_exception + assert op.completed + assert op.error is arbitrary_exception + + +@pytest.mark.describe( + "IoTHubHTTPTranslationStage - .run_op() -- Called with NotifyBlobUploadStatusOperation op" +) +class TestIoTHubHTTPTranslationStageRunOpCalledWithNotifyBlobUploadStatusOperation( + IoTHubHTTPTranslationStageTestConfig, StageRunOpTestBase +): + + # Because Storage/Blob related functionality is limited to Devices, configure the stage for a device + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + pl_config = config.IoTHubPipelineConfig() + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_configuration=pl_config + ) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + stage.device_id = "fake_device_id" + stage.module_id = None + stage.hostname = "fake_hostname" + return stage + + @pytest.fixture + def op(self, mocker): + return pipeline_ops_iothub_http.NotifyBlobUploadStatusOperation( + correlation_id="fake_correlation_id", + is_success=True, + status_code=203, + status_description="fake_description", + callback=mocker.MagicMock(), + ) + + @pytest.mark.it("Sends a new HTTPRequestAndResponseOperation op down the pipeline") + def test_sends_op_down(self, mocker, stage, op): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + @pytest.mark.it( + "Configures the HTTPRequestAndResponseOperation with request details for sending a Notify Blob Upload Status request" + ) + def test_sends_get_storage_request(self, mocker, stage, op, mock_http_path_iothub): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + # Validate request + assert mock_http_path_iothub.get_notify_blob_upload_status_path.call_count == 1 + assert mock_http_path_iothub.get_notify_blob_upload_status_path.call_args == mocker.call( + stage.device_id + ) + expected_path = mock_http_path_iothub.get_notify_blob_upload_status_path.return_value + + assert new_op.method == "POST" + assert new_op.path == expected_path + assert new_op.query_params == "api-version={}".format(pkg_constant.IOTHUB_API_VERSION) + + @pytest.mark.it( + "Configures the HTTPRequestAndResponseOperation with the headers for a Notify Blob Upload Status request" + ) + @pytest.mark.parametrize( + "custom_user_agent", + [ + pytest.param("", id="No custom user agent"), + pytest.param("MyCustomUserAgent", id="With custom user agent"), + pytest.param( + "My/Custom?User+Agent", id="With custom user agent containing reserved characters" + ), + pytest.param(12345, id="Non-string custom user agent"), + ], + ) + def test_new_op_headers(self, mocker, stage, op, custom_user_agent): + stage.pipeline_root.pipeline_configuration.product_info = custom_user_agent + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + # Validate headers + expected_user_agent = urllib.parse.quote_plus( + ProductInfo.get_iothub_user_agent() + str(custom_user_agent) + ) + + assert new_op.headers["Host"] == stage.hostname + assert new_op.headers["Content-Type"] == "application/json; charset=utf-8" + assert new_op.headers["Content-Length"] == len(new_op.body) + assert new_op.headers["User-Agent"] == expected_user_agent + + @pytest.mark.it( + "Configures the HTTPRequestAndResponseOperation with a body for a Notify Blob Upload Status request" + ) + def test_new_op_body(self, mocker, stage, op): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + # Validate body + header_dict = { + "correlationId": op.correlation_id, + "isSuccess": op.is_success, + "statusCode": op.request_status_code, + "statusDescription": op.status_description, + } + assert new_op.body == json.dumps(header_dict) + + @pytest.mark.it( + "Completes the original NotifyBlobUploadStatusOperation op (no error) if the new HTTPRequestAndResponseOperation is completed later on (no error) with a status code indicating success" + ) + def test_new_op_completes_with_good_code(self, mocker, stage, op): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + # Neither op is completed + assert not op.completed + assert op.error is None + assert not new_op.completed + assert new_op.error is None + + # Complete new op + new_op.status_code = 200 + new_op.complete() + + # Both ops are now completed successfully + assert new_op.completed + assert new_op.error is None + assert op.completed + assert op.error is None + + @pytest.mark.it( + "Completes the original NotifyBlobUploadStatusOperation op with a ServiceError if the new HTTPRequestAndResponseOperation is completed later on (no error) with a status code indicating non-success" + ) + @pytest.mark.parametrize( + "status_code", + [ + pytest.param(300, id="Status Code: 300"), + pytest.param(400, id="Status Code: 400"), + pytest.param(500, id="Status Code: 500"), + ], + ) + def test_new_op_completes_with_bad_code(self, mocker, stage, op, status_code): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + # Neither op is completed + assert not op.completed + assert op.error is None + assert not new_op.completed + assert new_op.error is None + + # Complete new op successfully (but with a bad status code) + new_op.status_code = status_code + new_op.complete() + + # The original op is now completed with a ServiceError + assert new_op.completed + assert new_op.error is None + assert op.completed + assert isinstance(op.error, ServiceError) + + @pytest.mark.it( + "Completes the original NotifyBlobUploadStatusOperation op with the error from the new HTTPRequestAndResponseOperation, if the HTTPRequestAndResponseOperation is completed later on with error" + ) + def test_new_op_completes_with_error(self, mocker, stage, op, arbitrary_exception): + stage.run_op(op) + + # Op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_http.HTTPRequestAndResponseOperation) + + # Neither op is completed + assert not op.completed + assert op.error is None + assert not new_op.completed + assert new_op.error is None + + # Complete new op with error + new_op.complete(error=arbitrary_exception) + + # The original op is now completed with a ServiceError + assert new_op.completed + assert new_op.error is arbitrary_exception + assert op.completed + assert op.error is arbitrary_exception diff --git a/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub_mqtt.py b/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub_mqtt.py index 6a797b9ca..96c771bdd 100644 --- a/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub_mqtt.py +++ b/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub_mqtt.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -import functools import logging import pytest import json @@ -21,22 +20,16 @@ from azure.iot.device.iothub.pipeline import ( pipeline_events_iothub, pipeline_ops_iothub, pipeline_stages_iothub_mqtt, + config, ) +from azure.iot.device.iothub.pipeline.exceptions import OperationError, PipelineError from azure.iot.device.iothub.models.message import Message from azure.iot.device.iothub.models.methods import MethodRequest, MethodResponse -from tests.common.pipeline.helpers import ( - assert_callback_failed, - assert_callback_succeeded, - all_common_ops, - all_common_events, - all_except, - make_mock_stage, - UnhandledException, -) +from tests.common.pipeline.helpers import all_common_ops, all_common_events, StageTestBase from tests.iothub.pipeline.helpers import all_iothub_ops, all_iothub_events from tests.common.pipeline import pipeline_stage_test from azure.iot.device import constant as pkg_constant -import uuid +from azure.iot.device.product_info import ProductInfo logging.basicConfig(level=logging.DEBUG) @@ -57,7 +50,7 @@ fake_device_id = "__fake_device_id__" fake_module_id = "__fake_module_id__" fake_hostname = "__fake_hostname__" fake_gateway_hostname = "__fake_gateway_hostname__" -fake_ca_cert = "__fake_ca_cert__" +fake_server_verification_cert = "__fake_server_verification_cert__" fake_client_cert = "__fake_client_cert__" fake_sas_token = "__fake_sas_token__" @@ -65,8 +58,14 @@ fake_message_id = "ee9e738b-4f47-447a-9892-5b1d1d7ca5" fake_message_id_encoded = "%24.mid=ee9e738b-4f47-447a-9892-5b1d1d7ca5" fake_message_body = "__fake_message_body__" fake_output_name = "__fake_output_name__" +fake_output_name_encoded = "%24.on=__fake_output_name__" fake_content_type = "text/json" fake_content_type_encoded = "%24.ct=text%2Fjson" +fake_content_encoding = "utf-16" +fake_content_encoding_encoded = "%24.ce=utf-16" +default_content_type = "application/json" +default_content_type_encoded = "%24.ct=application%2Fjson" +default_content_encoding_encoded = "%24.ce=utf-8" fake_message = Message(fake_message_body) security_message_interface_id_encoded = "%24.ifid=urn%3Aazureiot%3ASecurity%3ASecurityAgent%3A1" fake_request_id = "__fake_request_id__" @@ -104,7 +103,7 @@ fake_method_request_topic = "$iothub/methods/POST/{}/?$rid={}".format( ) fake_method_request_payload = "{}".encode("utf-8") -encoded_user_agent = urllib.parse.quote_plus(pkg_constant.USER_AGENT) +encoded_user_agent = urllib.parse.quote_plus(ProductInfo.get_iothub_user_agent()) fake_message_user_property_1_key = "is-muggle" fake_message_user_property_1_value = "yes" @@ -116,17 +115,18 @@ fake_message_user_property_2_encoded = "sorted-house=hufflepuff" ops_handled_by_this_stage = [ pipeline_ops_iothub.SetIoTHubConnectionArgsOperation, pipeline_ops_iothub.SendD2CMessageOperation, + pipeline_ops_base.UpdateSasTokenOperation, pipeline_ops_iothub.SendOutputEventOperation, pipeline_ops_iothub.SendMethodResponseOperation, - pipeline_ops_base.SendIotRequestOperation, + pipeline_ops_base.RequestOperation, pipeline_ops_base.EnableFeatureOperation, pipeline_ops_base.DisableFeatureOperation, ] events_handled_by_this_stage = [pipeline_events_mqtt.IncomingMQTTMessageEvent] -pipeline_stage_test.add_base_pipeline_stage_tests( - cls=pipeline_stages_iothub_mqtt.IoTHubMQTTConverterStage, +pipeline_stage_test.add_base_pipeline_stage_tests_old( + cls=pipeline_stages_iothub_mqtt.IoTHubMQTTTranslationStage, module=this_module, all_ops=all_common_ops + all_iothub_ops, handled_ops=ops_handled_by_this_stage, @@ -145,14 +145,14 @@ def create_message_with_user_properties(message_content, is_multiple): def create_security_message(message_content): - msg = Message(message_content, content_type=fake_content_type) + msg = Message(message_content) msg.set_as_security_message() return msg def create_message_with_system_and_user_properties(message_content, is_multiple): if is_multiple: - msg = Message(message_content, message_id=fake_message_id, content_type=fake_content_type) + msg = Message(message_content, message_id=fake_message_id, output_name=fake_output_name) else: msg = Message(message_content, message_id=fake_message_id) @@ -164,7 +164,7 @@ def create_message_with_system_and_user_properties(message_content, is_multiple) def create_security_message_with_system_and_user_properties(message_content, is_multiple): if is_multiple: - msg = Message(message_content, message_id=fake_message_id, content_type=fake_content_type) + msg = Message(message_content, message_id=fake_message_id, output_name=fake_output_name) else: msg = Message(message_content, message_id=fake_message_id) @@ -201,14 +201,9 @@ def create_message_for_output_with_system_and_user_properties(message_content, i @pytest.fixture -def stage(mocker): - return make_mock_stage(mocker, pipeline_stages_iothub_mqtt.IoTHubMQTTConverterStage) - - -@pytest.fixture -def set_connection_args(callback): +def set_connection_args(mocker): return pipeline_ops_iothub.SetIoTHubConnectionArgsOperation( - device_id=fake_device_id, hostname=fake_hostname, callback=callback + device_id=fake_device_id, hostname=fake_hostname, callback=mocker.MagicMock() ) @@ -223,48 +218,86 @@ def set_connection_args_for_module(set_connection_args): return set_connection_args -@pytest.fixture -def stage_configured_for_device(stage, set_connection_args_for_device, mocker): - set_connection_args_for_device.callback = None - stage.run_op(set_connection_args_for_device) - mocker.resetall() +class IoTHubMQTTTranslationStageTestBase(StageTestBase): + @pytest.fixture(autouse=True) + def stage_base_configuration(self, stage, mocker): + class NextStageForTest(pipeline_stages_base.PipelineStage): + def _run_op(self, op): + pass + next = NextStageForTest() + root = ( + pipeline_stages_base.PipelineRootStage(config.IoTHubPipelineConfig()) + .append_stage(stage) + .append_stage(next) + ) -@pytest.fixture -def stage_configured_for_module(stage, set_connection_args_for_module, mocker): - set_connection_args_for_module.callback = None - stage.run_op(set_connection_args_for_module) - mocker.resetall() + mocker.spy(stage, "_run_op") + mocker.spy(stage, "run_op") + mocker.spy(next, "_run_op") + mocker.spy(next, "run_op") -@pytest.fixture(params=["device", "module"]) -def stages_configured_for_both(request, stage, set_connection_args, mocker): - set_connection_args.callback = None - if request.param == "module": - set_connection_args.module_id = fake_module_id - stage.run_op(set_connection_args) - mocker.resetall() + return root + + @pytest.fixture + def stage(self, mocker): + stage = pipeline_stages_iothub_mqtt.IoTHubMQTTTranslationStage() + mocker.spy(stage, "send_op_down") + return stage + + @pytest.fixture + def stage_configured_for_device( + self, stage, stage_base_configuration, set_connection_args_for_device, mocker + ): + set_connection_args_for_device.callback = None + stage.run_op(set_connection_args_for_device) + mocker.resetall() + + @pytest.fixture + def stage_configured_for_module( + self, stage, stage_base_configuration, set_connection_args_for_module, mocker + ): + set_connection_args_for_module.callback = None + stage.run_op(set_connection_args_for_module) + mocker.resetall() + + @pytest.fixture(params=["device", "module"]) + def stages_configured_for_both( + self, request, stage, stage_base_configuration, set_connection_args, mocker + ): + set_connection_args.callback = None + if request.param == "module": + set_connection_args.module_id = fake_module_id + stage.run_op(set_connection_args) + mocker.resetall() @pytest.mark.describe( - "IoTHubMQTTConverterStage - .run_op() -- called with SetIoTHubConnectionArgsOperation" + "IoTHubMQTTTranslationStage - .run_op() -- called with SetIoTHubConnectionArgsOperation" ) -class TestIoTHubMQTTConverterWithSetAuthProviderArgs(object): +class TestIoTHubMQTTConverterWithSetAuthProviderArgs(IoTHubMQTTTranslationStageTestBase): @pytest.mark.it( - "Runs a pipeline_ops_mqtt.SetMQTTConnectionArgsOperation operation on the next stage" + "Runs a pipeline_ops_mqtt.SetMQTTConnectionArgsOperation worker operation on the next stage" ) - def test_runs_set_connection_args(self, stage, set_connection_args): + def test_runs_set_connection_args(self, mocker, stage, set_connection_args): + set_connection_args.spawn_worker_op = mocker.MagicMock() stage.run_op(set_connection_args) - assert stage.next._execute_op.call_count == 1 - new_op = stage.next._execute_op.call_args[0][0] - assert isinstance(new_op, pipeline_ops_mqtt.SetMQTTConnectionArgsOperation) + assert set_connection_args.spawn_worker_op.call_count == 1 + assert ( + set_connection_args.spawn_worker_op.call_args[1]["worker_op_type"] + is pipeline_ops_mqtt.SetMQTTConnectionArgsOperation + ) + worker = set_connection_args.spawn_worker_op.return_value + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(worker) @pytest.mark.it( "Sets connection_args.client_id to auth_provider_args.device_id if auth_provider_args.module_id is None" ) def test_sets_client_id_for_devices(self, stage, set_connection_args): stage.run_op(set_connection_args) - new_op = stage.next._execute_op.call_args[0][0] + new_op = stage.next._run_op.call_args[0][0] assert new_op.client_id == fake_device_id @pytest.mark.it( @@ -272,7 +305,7 @@ class TestIoTHubMQTTConverterWithSetAuthProviderArgs(object): ) def test_sets_client_id_for_modules(self, stage, set_connection_args_for_module): stage.run_op(set_connection_args_for_module) - new_op = stage.next._execute_op.call_args[0][0] + new_op = stage.next._run_op.call_args[0][0] assert new_op.client_id == "{}/{}".format(fake_device_id, fake_module_id) @pytest.mark.it( @@ -280,7 +313,7 @@ class TestIoTHubMQTTConverterWithSetAuthProviderArgs(object): ) def test_sets_hostname_if_no_gateway(self, stage, set_connection_args): stage.run_op(set_connection_args) - new_op = stage.next._execute_op.call_args[0][0] + new_op = stage.next._run_op.call_args[0][0] assert new_op.hostname == fake_hostname @pytest.mark.it( @@ -289,7 +322,7 @@ class TestIoTHubMQTTConverterWithSetAuthProviderArgs(object): def test_sets_hostname_if_yes_gateway(self, stage, set_connection_args): set_connection_args.gateway_hostname = fake_gateway_hostname stage.run_op(set_connection_args) - new_op = stage.next._execute_op.call_args[0][0] + new_op = stage.next._run_op.call_args[0][0] assert new_op.hostname == fake_gateway_hostname @pytest.mark.it( @@ -297,7 +330,7 @@ class TestIoTHubMQTTConverterWithSetAuthProviderArgs(object): ) def test_sets_device_username_if_no_gateway(self, stage, set_connection_args): stage.run_op(set_connection_args) - new_op = stage.next._execute_op.call_args[0][0] + new_op = stage.next._run_op.call_args[0][0] assert new_op.username == "{}/{}/?api-version={}&DeviceClientType={}".format( fake_hostname, fake_device_id, pkg_constant.IOTHUB_API_VERSION, encoded_user_agent ) @@ -308,7 +341,7 @@ class TestIoTHubMQTTConverterWithSetAuthProviderArgs(object): def test_sets_device_username_if_yes_gateway(self, stage, set_connection_args): set_connection_args.gateway_hostname = fake_gateway_hostname stage.run_op(set_connection_args) - new_op = stage.next._execute_op.call_args[0][0] + new_op = stage.next._run_op.call_args[0][0] assert new_op.username == "{}/{}/?api-version={}&DeviceClientType={}".format( fake_hostname, fake_device_id, pkg_constant.IOTHUB_API_VERSION, encoded_user_agent ) @@ -318,7 +351,7 @@ class TestIoTHubMQTTConverterWithSetAuthProviderArgs(object): ) def test_sets_module_username_if_no_gateway(self, stage, set_connection_args_for_module): stage.run_op(set_connection_args_for_module) - new_op = stage.next._execute_op.call_args[0][0] + new_op = stage.next._run_op.call_args[0][0] assert new_op.username == "{}/{}/{}/?api-version={}&DeviceClientType={}".format( fake_hostname, fake_device_id, @@ -333,7 +366,7 @@ class TestIoTHubMQTTConverterWithSetAuthProviderArgs(object): def test_sets_module_username_if_yes_gateway(self, stage, set_connection_args_for_module): set_connection_args_for_module.gateway_hostname = fake_gateway_hostname stage.run_op(set_connection_args_for_module) - new_op = stage.next._execute_op.call_args[0][0] + new_op = stage.next._run_op.call_args[0][0] assert new_op.username == "{}/{}/{}/?api-version={}&DeviceClientType={}".format( fake_hostname, fake_device_id, @@ -342,59 +375,70 @@ class TestIoTHubMQTTConverterWithSetAuthProviderArgs(object): encoded_user_agent, ) - @pytest.mark.it("Sets connection_args.ca_cert to auth_provider.ca_cert") - def test_sets_ca_cert(self, stage, set_connection_args): - set_connection_args.ca_cert = fake_ca_cert + @pytest.mark.it( + "Appends product_info to connection_args.username to if self.pipeline_root.pipeline_configuration.product_info is not None" + ) + @pytest.mark.parametrize( + "fake_product_info, expected_product_info", + [ + ("", ""), + ("__fake:product:info__", "__fake%3Aproduct%3Ainfo__"), + (4, 4), + ( + ["fee,fi,fo,fum"], + "%5B%27fee%2Cfi%2Cfo%2Cfum%27%5D", + ), # URI Encoding for str version of list + ( + {"fake_key": "fake_value"}, + "%7B%27fake_key%27%3A%20%27fake_value%27%7D", + ), # URI Encoding for str version of dict + ], + ) + def test_appends_product_info_to_device_username( + self, stage, set_connection_args, fake_product_info, expected_product_info + ): + set_connection_args.gateway_hostname = fake_gateway_hostname + stage.pipeline_root.pipeline_configuration.product_info = fake_product_info stage.run_op(set_connection_args) - new_op = stage.next._execute_op.call_args[0][0] - assert new_op.ca_cert == fake_ca_cert + new_op = stage.next._run_op.call_args[0][0] + assert new_op.username == "{}/{}/?api-version={}&DeviceClientType={}{}".format( + fake_hostname, + fake_device_id, + pkg_constant.IOTHUB_API_VERSION, + encoded_user_agent, + expected_product_info, + ) + + @pytest.mark.it( + "Sets connection_args.server_verification_cert to auth_provider.server_verification_cert" + ) + def test_sets_server_verification_cert(self, stage, set_connection_args): + set_connection_args.server_verification_cert = fake_server_verification_cert + stage.run_op(set_connection_args) + new_op = stage.next._run_op.call_args[0][0] + assert new_op.server_verification_cert == fake_server_verification_cert @pytest.mark.it("Sets connection_args.client_cert to auth_provider.client_cert") def test_sets_client_cert(self, stage, set_connection_args): set_connection_args.client_cert = fake_client_cert stage.run_op(set_connection_args) - new_op = stage.next._execute_op.call_args[0][0] + new_op = stage.next._run_op.call_args[0][0] assert new_op.client_cert == fake_client_cert @pytest.mark.it("Sets connection_args.sas_token to auth_provider.sas_token.") def test_sets_sas_token(self, stage, set_connection_args): set_connection_args.sas_token = fake_sas_token stage.run_op(set_connection_args) - new_op = stage.next._execute_op.call_args[0][0] + new_op = stage.next._run_op.call_args[0][0] assert new_op.sas_token == fake_sas_token - @pytest.mark.it( - "Calls the SetIoTHubConnectionArgsOperation callback with error if the pipeline_ops_mqtt.SetMQTTConnectionArgsOperation raises an Exception" - ) - def test_set_connection_args_raises_exception( - self, stage, mocker, fake_exception, set_connection_args - ): - stage.next._execute_op = mocker.Mock(side_effect=fake_exception) - stage.run_op(set_connection_args) - assert_callback_failed(op=set_connection_args, error=fake_exception) - - @pytest.mark.it( - "Allows any BaseExceptions raised inside the pipeline_ops_mqtt.SetMQTTConnectionArgsOperation operation to propagate" - ) - def test_set_connection_args_raises_base_exception( - self, stage, mocker, fake_base_exception, set_connection_args - ): - stage.next._execute_op = mocker.Mock(side_effect=fake_base_exception) - with pytest.raises(UnhandledException): - stage.run_op(set_connection_args) - - @pytest.mark.it( - "Calls the SetIoTHubConnectionArgsOperation callback with no error if the pipeline_ops_mqtt.SetMQTTConnectionArgsOperation operation succeeds" - ) - def test_set_connection_args_succeeds(self, stage, mocker, set_connection_args): - stage.run_op(set_connection_args) - assert_callback_succeeded(op=set_connection_args) - @pytest.mark.describe( - "IoTHubMQTTConverterStage - .run_op() -- called with UpdateSasTokenOperation if the transport is disconnected" + "IoTHubMQTTTranslationStage - .run_op() -- called with UpdateSasTokenOperation if the transport is disconnected" ) -class TestIoTHubMQTTConverterWithUpdateSasTokenOperationDisconnected(object): +class TestIoTHubMQTTConverterWithUpdateSasTokenOperationDisconnected( + IoTHubMQTTTranslationStageTestBase +): @pytest.fixture def op(self, mocker): return pipeline_ops_base.UpdateSasTokenOperation( @@ -407,33 +451,24 @@ class TestIoTHubMQTTConverterWithUpdateSasTokenOperationDisconnected(object): @pytest.mark.it("Immediately passes the operation to the next stage") def test_passes_op_immediately(self, stage, op): - op.action = "pend" stage.run_op(op) assert stage.next.run_op.call_count == 1 assert stage.next.run_op.call_args[0][0] == op - @pytest.mark.it("Completes the op with failure if some lower stage returns failure") - def test_lower_stage_update_sas_token_fails(self, stage, op): - op.action = "error" - stage.run_op(op) - assert_callback_failed(op=op, error=Exception) - - @pytest.mark.it("Completes the op with success if some lower stage returns success") - def test_lower_stage_update_sas_token_succeeds(self, stage, op): - op.action = "pass" - stage.run_op(op) - assert_callback_succeeded(op=op) - @pytest.mark.describe( - "IoTHubMQTTConverterStage - .run_op() -- called with UpdateSasTokenOperation if the transport is connected" + "IoTHubMQTTTranslationStage - .run_op() -- called with UpdateSasTokenOperation if the transport is connected" ) -class TestIoTHubMQTTConverterWithUpdateSasTokenOperationConnected(object): +class TestIoTHubMQTTConverterWithUpdateSasTokenOperationConnected( + IoTHubMQTTTranslationStageTestBase +): @pytest.fixture def op(self, mocker): - return pipeline_ops_base.UpdateSasTokenOperation( + op = pipeline_ops_base.UpdateSasTokenOperation( sas_token=fake_sas_token, callback=mocker.MagicMock() ) + mocker.spy(op, "complete") + return op @pytest.fixture(autouse=True) def transport_is_connected(self, stage): @@ -441,27 +476,18 @@ class TestIoTHubMQTTConverterWithUpdateSasTokenOperationConnected(object): @pytest.mark.it("Immediately passes the operation to the next stage") def test_passes_op_immediately(self, stage, op): - op.action = "pend" stage.run_op(op) assert stage.next.run_op.call_count == 1 assert stage.next.run_op.call_args[0][0] == op @pytest.mark.it( - "Completes the op with failure if some lower stage returns failure for the UpdateSasTokenOperation" + "Passes down a ReauthorizeConnectionOperation instead of completing the op with success after the lower level stage returns success for the UpdateSasTokenOperation" ) - def test_lower_stage_update_sas_token_fails(self, stage, op): - op.action = "fail" - stage.run_op(op) - assert_callback_failed(op=op, error=Exception) - - @pytest.mark.it( - "Passes down a ReconnectOperation instead of completing the op with success after the lower level stage returns success for the UpdateSasTokenOperation" - ) - def test_passes_down_reconnect(self, stage, op, mocker): + def test_passes_down_reauthorize_connection(self, stage, op, mocker): def run_op(op): print("in run_op {}".format(op.__class__.__name__)) if isinstance(op, pipeline_ops_base.UpdateSasTokenOperation): - op.callback(op) + op.complete(error=None) else: pass @@ -471,35 +497,53 @@ class TestIoTHubMQTTConverterWithUpdateSasTokenOperationConnected(object): assert stage.next.run_op.call_count == 2 assert stage.next.run_op.call_args_list[0][0][0] == op assert isinstance( - stage.next.run_op.call_args_list[1][0][0], pipeline_ops_base.ReconnectOperation + stage.next.run_op.call_args_list[1][0][0], + pipeline_ops_base.ReauthorizeConnectionOperation, ) - assert op.callback.call_count == 0 + # CT-TODO: Make this test clearer - this below assertion is a bit confusing + # What is happening here is that the run_op defined above for the mock only completes + # ops of type UpdateSasTokenOperation (i.e. variable 'op'). However, completing the + # op triggers a callback which halts the completion, and then spawn a reauthorize_connection worker op, + # which must be completed before full completion of 'op' can occur. However, as the above + # run_op mock only completes ops of type UpdateSasTokenOperation, this never happens, + # thus op is not completed. + assert not op.completed + # CT-TODO: remove this once able. This test does not have a high degree of accuracy, and its contents + # could be tested better once stage tests are restructured. This test is overlapping with tests of + # worker op functionality, that should not be being tested at this granularity here. @pytest.mark.it( - "Completes the op with success if some lower level stage returns success for the ReconnectOperation" + "Completes the op with success if some lower level stage returns success for the ReauthorizeConnectionOperation" ) - def test_reconnect_succeeds(self, stage, op): + def test_reauthorize_connection_succeeds(self, mocker, stage, next_stage_succeeds, op): # default is for stage.next.run_op to return success for all ops stage.run_op(op) assert stage.next.run_op.call_count == 2 assert stage.next.run_op.call_args_list[0][0][0] == op assert isinstance( - stage.next.run_op.call_args_list[1][0][0], pipeline_ops_base.ReconnectOperation + stage.next.run_op.call_args_list[1][0][0], + pipeline_ops_base.ReauthorizeConnectionOperation, ) - assert_callback_succeeded(op=op) + assert op.completed + assert op.complete.call_count == 2 # op was completed twice due to an uncompletion + # most recent call, i.e. one triggered by the successful reauthorize_connection + assert op.complete.call_args == mocker.call(error=None) + + # CT-TODO: As above, remove/restructure ASAP @pytest.mark.it( - "Completes the op with failure if some lower level stage returns failure for the ReconnectOperation" + "Completes the op with failure if some lower level stage returns failure for the ReauthorizeConnectionOperation" ) - def test_reconnect_fails(self, stage, op, mocker, fake_exception): + def test_reauthorize_connection_fails(self, stage, op, mocker, arbitrary_exception): + cb = op.callback_stack[0] + def run_op(op): print("in run_op {}".format(op.__class__.__name__)) if isinstance(op, pipeline_ops_base.UpdateSasTokenOperation): - op.callback(op) - elif isinstance(op, pipeline_ops_base.ReconnectOperation): - op.error = fake_exception - op.callback(op) + op.complete(error=None) + elif isinstance(op, pipeline_ops_base.ReauthorizeConnectionOperation): + op.complete(error=arbitrary_exception) else: pass @@ -509,105 +553,121 @@ class TestIoTHubMQTTConverterWithUpdateSasTokenOperationConnected(object): assert stage.next.run_op.call_count == 2 assert stage.next.run_op.call_args_list[0][0][0] == op assert isinstance( - stage.next.run_op.call_args_list[1][0][0], pipeline_ops_base.ReconnectOperation + stage.next.run_op.call_args_list[1][0][0], + pipeline_ops_base.ReauthorizeConnectionOperation, ) - assert_callback_failed(op=op, error=fake_exception) + assert cb.call_count == 1 + assert cb.call_args == mocker.call(op=op, error=arbitrary_exception) basic_ops = [ { "op_class": pipeline_ops_iothub.SendD2CMessageOperation, - "op_init_kwargs": {"message": fake_message}, + "op_init_kwargs": {"message": fake_message, "callback": None}, "new_op_class": pipeline_ops_mqtt.MQTTPublishOperation, }, { "op_class": pipeline_ops_iothub.SendOutputEventOperation, - "op_init_kwargs": {"message": fake_message}, + "op_init_kwargs": {"message": fake_message, "callback": None}, "new_op_class": pipeline_ops_mqtt.MQTTPublishOperation, }, { "op_class": pipeline_ops_iothub.SendMethodResponseOperation, - "op_init_kwargs": {"method_response": fake_method_response}, + "op_init_kwargs": {"method_response": fake_method_response, "callback": None}, "new_op_class": pipeline_ops_mqtt.MQTTPublishOperation, }, { "op_class": pipeline_ops_base.EnableFeatureOperation, - "op_init_kwargs": {"feature_name": constant.C2D_MSG}, + "op_init_kwargs": {"feature_name": constant.C2D_MSG, "callback": None}, "new_op_class": pipeline_ops_mqtt.MQTTSubscribeOperation, }, { "op_class": pipeline_ops_base.DisableFeatureOperation, - "op_init_kwargs": {"feature_name": constant.C2D_MSG}, + "op_init_kwargs": {"feature_name": constant.C2D_MSG, "callback": None}, "new_op_class": pipeline_ops_mqtt.MQTTUnsubscribeOperation, }, ] +# CT-TODO: simplify this @pytest.mark.parametrize( "params", basic_ops, ids=["{}->{}".format(x["op_class"].__name__, x["new_op_class"].__name__) for x in basic_ops], ) -@pytest.mark.describe("IoTHubMQTTConverterStage - .run_op() -- called with basic MQTT operations") -class TestIoTHubMQTTConverterBasicOperations(object): +@pytest.mark.describe("IoTHubMQTTTranslationStage - .run_op() -- called with basic MQTT operations") +class TestIoTHubMQTTConverterBasicOperations(IoTHubMQTTTranslationStageTestBase): @pytest.fixture - def op(self, params, callback): + def op(self, params, mocker): op = params["op_class"](**params["op_init_kwargs"]) - op.callback = callback + mocker.spy(op, "spawn_worker_op") return op - @pytest.mark.it("Runs an operation on the next stage") - def test_runs_publish(self, params, stage, stages_configured_for_both, op): + @pytest.mark.it("Runs a worker operation on the next stage") + def test_spawn_worker_op(self, params, stage, stages_configured_for_both, op): stage.run_op(op) - new_op = stage.next._execute_op.call_args[0][0] + + assert op.spawn_worker_op.call_count == 1 + assert op.spawn_worker_op.call_args[1]["worker_op_type"] is params["new_op_class"] + new_op = stage.next._run_op.call_args[0][0] assert isinstance(new_op, params["new_op_class"]) - @pytest.mark.it("Calls the original op callback with error if the new_op raises an exception") - def test_operation_raises_exception( - self, params, mocker, stage, stages_configured_for_both, op, fake_exception - ): - stage.next._execute_op = mocker.Mock(side_effect=fake_exception) - stage.run_op(op) - assert_callback_failed(op=op, error=fake_exception) - - @pytest.mark.it("Allows any any BaseExceptions raised in the new_op to propagate") - def test_operation_raises_base_exception( - self, params, mocker, stage, stages_configured_for_both, op, fake_base_exception - ): - stage.next._execute_op = mocker.Mock(side_effect=fake_base_exception) - with pytest.raises(UnhandledException): - stage.run_op(op) - - @pytest.mark.it("Calls the original op callback with no error if the new_op operation succeeds") - def test_operation_succeeds(self, params, stage, stages_configured_for_both, op): - stage.run_op(op) - assert_callback_succeeded(op) - publish_ops = [ { "name": "send telemetry", "stage_type": "device", "op_class": pipeline_ops_iothub.SendD2CMessageOperation, - "op_init_kwargs": {"message": Message(fake_message_body)}, + "op_init_kwargs": {"message": Message(fake_message_body), "callback": None}, "topic": "devices/{}/messages/events/".format(fake_device_id), "publish_payload": fake_message_body, }, + { + "name": "send telemetry with content type and content encoding", + "stage_type": "device", + "op_class": pipeline_ops_iothub.SendD2CMessageOperation, + "op_init_kwargs": { + "message": Message( + fake_message_body, + content_type=fake_content_type, + content_encoding=fake_content_encoding, + ), + "callback": None, + }, + "topic": "devices/{}/messages/events/{}&{}".format( + fake_device_id, fake_content_type_encoded, fake_content_encoding_encoded + ), + "publish_payload": fake_message_body, + }, + { + "name": "send telemetry overriding only the content type", + "stage_type": "device", + "op_class": pipeline_ops_iothub.SendD2CMessageOperation, + "op_init_kwargs": { + "message": Message(fake_message_body, content_type=fake_content_type), + "callback": None, + }, + "topic": "devices/{}/messages/events/{}".format(fake_device_id, fake_content_type_encoded), + "publish_payload": fake_message_body, + }, { "name": "send telemetry with single system property", "stage_type": "device", "op_class": pipeline_ops_iothub.SendD2CMessageOperation, - "op_init_kwargs": {"message": Message(fake_message_body, content_type=fake_content_type)}, - "topic": "devices/{}/messages/events/{}".format(fake_device_id, fake_content_type_encoded), + "op_init_kwargs": { + "message": Message(fake_message_body, output_name=fake_output_name), + "callback": None, + }, + "topic": "devices/{}/messages/events/{}".format(fake_device_id, fake_output_name_encoded), "publish_payload": fake_message_body, }, { "name": "send security message", "stage_type": "device", "op_class": pipeline_ops_iothub.SendD2CMessageOperation, - "op_init_kwargs": {"message": create_security_message(fake_message_body)}, - "topic": "devices/{}/messages/events/{}&{}".format( - fake_device_id, fake_content_type_encoded, security_message_interface_id_encoded + "op_init_kwargs": {"message": create_security_message(fake_message_body), "callback": None}, + "topic": "devices/{}/messages/events/{}".format( + fake_device_id, security_message_interface_id_encoded ), "publish_payload": fake_message_body, }, @@ -617,11 +677,12 @@ publish_ops = [ "op_class": pipeline_ops_iothub.SendD2CMessageOperation, "op_init_kwargs": { "message": Message( - fake_message_body, message_id=fake_message_id, content_type=fake_content_type - ) + fake_message_body, message_id=fake_message_id, output_name=fake_output_name + ), + "callback": None, }, "topic": "devices/{}/messages/events/{}&{}".format( - fake_device_id, fake_message_id_encoded, fake_content_type_encoded + fake_device_id, fake_output_name_encoded, fake_message_id_encoded ), "publish_payload": fake_message_body, }, @@ -630,7 +691,8 @@ publish_ops = [ "stage_type": "device", "op_class": pipeline_ops_iothub.SendD2CMessageOperation, "op_init_kwargs": { - "message": create_message_with_user_properties(fake_message_body, is_multiple=False) + "message": create_message_with_user_properties(fake_message_body, is_multiple=False), + "callback": None, }, "topic": "devices/{}/messages/events/{}".format( fake_device_id, fake_message_user_property_1_encoded @@ -642,7 +704,8 @@ publish_ops = [ "stage_type": "device", "op_class": pipeline_ops_iothub.SendD2CMessageOperation, "op_init_kwargs": { - "message": create_message_with_user_properties(fake_message_body, is_multiple=True) + "message": create_message_with_user_properties(fake_message_body, is_multiple=True), + "callback": None, }, # For more than 1 user property the order could be different, creating 2 different topics "topic1": "devices/{}/messages/events/{}&{}".format( @@ -664,7 +727,8 @@ publish_ops = [ "op_init_kwargs": { "message": create_message_with_system_and_user_properties( fake_message_body, is_multiple=False - ) + ), + "callback": None, }, "topic": "devices/{}/messages/events/{}&{}".format( fake_device_id, fake_message_id_encoded, fake_message_user_property_1_encoded @@ -678,20 +742,21 @@ publish_ops = [ "op_init_kwargs": { "message": create_message_with_system_and_user_properties( fake_message_body, is_multiple=True - ) + ), + "callback": None, }, # For more than 1 user property the order could be different, creating 2 different topics "topic1": "devices/{}/messages/events/{}&{}&{}&{}".format( fake_device_id, + fake_output_name_encoded, fake_message_id_encoded, - fake_content_type_encoded, fake_message_user_property_1_encoded, fake_message_user_property_2_encoded, ), "topic2": "devices/{}/messages/events/{}&{}&{}&{}".format( fake_device_id, + fake_output_name_encoded, fake_message_id_encoded, - fake_content_type_encoded, fake_message_user_property_2_encoded, fake_message_user_property_1_encoded, ), @@ -704,21 +769,22 @@ publish_ops = [ "op_init_kwargs": { "message": create_security_message_with_system_and_user_properties( fake_message_body, is_multiple=True - ) + ), + "callback": None, }, # For more than 1 user property the order could be different, creating 2 different topics "topic1": "devices/{}/messages/events/{}&{}&{}&{}&{}".format( fake_device_id, + fake_output_name_encoded, fake_message_id_encoded, - fake_content_type_encoded, security_message_interface_id_encoded, fake_message_user_property_1_encoded, fake_message_user_property_2_encoded, ), "topic2": "devices/{}/messages/events/{}&{}&{}&{}&{}".format( fake_device_id, + fake_output_name_encoded, fake_message_id_encoded, - fake_content_type_encoded, security_message_interface_id_encoded, fake_message_user_property_2_encoded, fake_message_user_property_1_encoded, @@ -729,23 +795,49 @@ publish_ops = [ "name": "send output", "stage_type": "module", "op_class": pipeline_ops_iothub.SendOutputEventOperation, - "op_init_kwargs": {"message": Message(fake_message_body, output_name=fake_output_name)}, + "op_init_kwargs": { + "message": Message(fake_message_body, output_name=fake_output_name), + "callback": None, + }, "topic": "devices/{}/modules/{}/messages/events/%24.on={}".format( fake_device_id, fake_module_id, fake_output_name ), "publish_payload": fake_message_body, }, + { + "name": "send output with content type and content encoding", + "stage_type": "module", + "op_class": pipeline_ops_iothub.SendOutputEventOperation, + "op_init_kwargs": { + "message": Message( + fake_message_body, + output_name=fake_output_name, + content_type=fake_content_type, + content_encoding=fake_content_encoding, + ), + "callback": None, + }, + "topic": "devices/{}/modules/{}/messages/events/%24.on={}&{}&{}".format( + fake_device_id, + fake_module_id, + fake_output_name, + fake_content_type_encoded, + fake_content_encoding_encoded, + ), + "publish_payload": fake_message_body, + }, { "name": "send output with system properties", "stage_type": "module", "op_class": pipeline_ops_iothub.SendOutputEventOperation, "op_init_kwargs": { "message": Message( - fake_message_body, output_name=fake_output_name, content_type=fake_content_type - ) + fake_message_body, message_id=fake_message_id, output_name=fake_output_name + ), + "callback": None, }, "topic": "devices/{}/modules/{}/messages/events/%24.on={}&{}".format( - fake_device_id, fake_module_id, fake_output_name, fake_content_type_encoded + fake_device_id, fake_module_id, fake_output_name, fake_message_id_encoded ), "publish_payload": fake_message_body, }, @@ -756,7 +848,8 @@ publish_ops = [ "op_init_kwargs": { "message": create_message_for_output_with_user_properties( fake_message_body, is_multiple=False - ) + ), + "callback": None, }, "topic": "devices/{}/modules/{}/messages/events/%24.on={}&{}".format( fake_device_id, fake_module_id, fake_output_name, fake_message_user_property_1_encoded @@ -770,7 +863,8 @@ publish_ops = [ "op_init_kwargs": { "message": create_message_for_output_with_user_properties( fake_message_body, is_multiple=True - ) + ), + "callback": None, }, "topic1": "devices/{}/modules/{}/messages/events/%24.on={}&{}&{}".format( fake_device_id, @@ -795,7 +889,8 @@ publish_ops = [ "op_init_kwargs": { "message": create_message_for_output_with_system_and_user_properties( fake_message_body, is_multiple=False - ) + ), + "callback": None, }, "topic": "devices/{}/modules/{}/messages/events/%24.on={}&{}&{}".format( fake_device_id, @@ -810,7 +905,7 @@ publish_ops = [ "name": "send method result", "stage_type": "both", "op_class": pipeline_ops_iothub.SendMethodResponseOperation, - "op_init_kwargs": {"method_response": fake_method_response}, + "op_init_kwargs": {"method_response": fake_method_response, "callback": None}, "topic": "$iothub/methods/res/__fake_method_status__/?$rid=__fake_request_id__", "publish_payload": json.dumps(fake_method_payload), }, @@ -818,12 +913,12 @@ publish_ops = [ @pytest.mark.parametrize("params", publish_ops, ids=[x["name"] for x in publish_ops]) -@pytest.mark.describe("IoTHubMQTTConverterStage - .run_op() -- called with publish operations") -class TestIoTHubMQTTConverterForPublishOps(object): +@pytest.mark.describe("IoTHubMQTTTranslationStage - .run_op() -- called with publish operations") +class TestIoTHubMQTTConverterForPublishOps(IoTHubMQTTTranslationStageTestBase): @pytest.fixture - def op(self, params, callback): + def op(self, params, mocker): op = params["op_class"](**params["op_init_kwargs"]) - op.callback = callback + op.callback = mocker.MagicMock() return op @pytest.mark.it("Uses the correct topic and encodes message properties string when publishing") @@ -833,7 +928,7 @@ class TestIoTHubMQTTConverterForPublishOps(object): elif params["stage_type"] == "module" and not stage.module_id: pytest.skip() stage.run_op(op) - new_op = stage.next._execute_op.call_args[0][0] + new_op = stage.next._run_op.call_args[0][0] if "multiple user properties" in params["name"]: assert new_op.topic == params["topic1"] or new_op.topic == params["topic2"] else: @@ -842,7 +937,7 @@ class TestIoTHubMQTTConverterForPublishOps(object): @pytest.mark.it("Sends the body in the payload of the MQTT publish operation") def test_sends_correct_body(self, stage, stages_configured_for_both, params, op): stage.run_op(op) - new_op = stage.next._execute_op.call_args[0][0] + new_op = stage.next._run_op.call_args[0][0] assert new_op.payload == params["publish_payload"] @@ -873,9 +968,9 @@ sub_unsub_operations = [ @pytest.mark.describe( - "IoTHubMQTTConverterStage - .run_op() -- called with EnableFeature or DisableFeature" + "IoTHubMQTTTranslationStage - .run_op() -- called with EnableFeature or DisableFeature" ) -class TestIoTHubMQTTConverterWithEnableFeature(object): +class TestIoTHubMQTTConverterWithEnableFeature(IoTHubMQTTTranslationStageTestBase): @pytest.mark.parametrize( "topic_parameters", feature_name_to_subscribe_topic, @@ -897,10 +992,12 @@ class TestIoTHubMQTTConverterWithEnableFeature(object): pytest.skip() elif topic_parameters["stage_type"] == "module" and not stage.module_id: pytest.skip() - stage.next._execute_op = mocker.Mock() - op = op_parameters["op_class"](feature_name=topic_parameters["feature_name"]) + stage.next._run_op = mocker.Mock() + op = op_parameters["op_class"]( + feature_name=topic_parameters["feature_name"], callback=mocker.MagicMock() + ) stage.run_op(op) - new_op = stage.next._execute_op.call_args[0][0] + new_op = stage.next._run_op.call_args[0][0] assert isinstance(new_op, op_parameters["new_op"]) assert new_op.topic == topic_parameters["topic"] @@ -911,29 +1008,30 @@ class TestIoTHubMQTTConverterWithEnableFeature(object): ids=[x["op_class"].__name__ for x in sub_unsub_operations], ) def test_fails_on_invalid_feature_name( - self, mocker, stage, stages_configured_for_both, op_parameters, callback + self, mocker, stage, stages_configured_for_both, op_parameters ): - op = op_parameters["op_class"](feature_name=invalid_feature_name, callback=callback) - callback.reset_mock() + op = op_parameters["op_class"]( + feature_name=invalid_feature_name, callback=mocker.MagicMock() + ) + mocker.spy(op, "complete") stage.run_op(op) - assert callback.call_count == 1 - callback_arg = op.callback.call_args[0][0] - assert callback_arg == op - assert isinstance(callback_arg.error, KeyError) + assert op.complete.call_count == 1 + assert isinstance(op.complete.call_args[1]["error"], KeyError) + # assert_callback_failed(op=op, error=KeyError) @pytest.fixture def add_pipeline_root(stage, mocker): - root = pipeline_stages_base.PipelineRootStage() + root = pipeline_stages_base.PipelineRootStage(mocker.MagicMock()) mocker.spy(root, "handle_pipeline_event") stage.previous = root stage.pipeline_root = root @pytest.mark.describe( - "IoTHubMQTTConverterStage - .handle_pipeline_event() -- called with unmatched topic" + "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- called with unmatched topic" ) -class TestIoTHubMQTTConverterHandlePipelineEvent(object): +class TestIoTHubMQTTConverterHandlePipelineEvent(IoTHubMQTTTranslationStageTestBase): @pytest.mark.it("Passes up any mqtt messages with topics that aren't matched by this stage") def test_passes_up_mqtt_message_with_unknown_topic( self, stage, stages_configured_for_both, add_pipeline_root, mocker @@ -954,9 +1052,9 @@ def c2d_event(): @pytest.mark.describe( - "IoTHubMQTTConverterStage - .handle_pipeline_event() -- called with C2D topic" + "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- called with C2D topic" ) -class TestIoTHubMQTTConverterHandlePipelineEventC2D(object): +class TestIoTHubMQTTConverterHandlePipelineEventC2D(IoTHubMQTTTranslationStageTestBase): @pytest.mark.it( "Converts mqtt message with topic devices/device_id/message/devicebound/ to c2d event" ) @@ -999,8 +1097,8 @@ class TestIoTHubMQTTConverterHandlePipelineEventC2D(object): assert stage.previous.handle_pipeline_event.call_args == mocker.call(event) -@pytest.mark.describe("IotHubMQTTConverter - .run_op() -- called with SendIotRequestOperation") -class TestIotHubMQTTConverterWithSendIotRequest(object): +@pytest.mark.describe("IotHubMQTTConverter - .run_op() -- called with RequestOperation") +class TestIotHubMQTTConverterWithSendIotRequest(IoTHubMQTTTranslationStageTestBase): @pytest.fixture def fake_request_type(self): return "twin" @@ -1033,84 +1131,51 @@ class TestIotHubMQTTConverterWithSendIotRequest(object): fake_resource_location, fake_request_body, fake_request_id, - callback, + mocker, ): - return pipeline_ops_base.SendIotRequestOperation( + op = pipeline_ops_base.RequestOperation( request_type=fake_request_type, method=fake_method, resource_location=fake_resource_location, request_body=fake_request_body, request_id=fake_request_id, - callback=callback, + callback=mocker.MagicMock(), ) + mocker.spy(op, "complete") + mocker.spy(op, "spawn_worker_op") + return op - @pytest.mark.it( - "calls the op callback with a NotImplementedError if request_type is not 'twin'" - ) + @pytest.mark.it("calls the op callback with an OperationError if request_type is not 'twin'") def test_sends_bad_request_type(self, stage, op): op.request_type = "not_twin" stage.run_op(op) - assert_callback_failed(op=op, error=NotImplementedError) + assert op.complete.call_count == 1 + assert isinstance(op.complete.call_args[1]["error"], OperationError) @pytest.mark.it( - "Runs an MQTTPublishOperation on the next stage with the topic formated as '$iothub/twin/{method}{resource_location}?$rid={request_id}' and the payload as the request_body" + "Runs an MQTTPublishOperation as a worker op on the next stage with the topic formated as '$iothub/twin/{method}{resource_location}?$rid={request_id}' and the payload as the request_body" ) def test_sends_new_operation( self, stage, op, fake_method, fake_resource_location, fake_request_id, fake_request_body ): stage.run_op(op) - assert stage.next.run_op.call_count == 1 - new_op = stage.next.run_op.call_args[0][0] - assert isinstance(new_op, pipeline_ops_mqtt.MQTTPublishOperation) - assert new_op.topic == "$iothub/twin/{method}{resource_location}?$rid={request_id}".format( - method=fake_method, resource_location=fake_resource_location, request_id=fake_request_id + assert op.spawn_worker_op.call_count == 1 + assert ( + op.spawn_worker_op.call_args[1]["worker_op_type"] + is pipeline_ops_mqtt.MQTTPublishOperation ) - assert new_op.payload == fake_request_body - - @pytest.mark.it("Returns an Exception through the op callback if there is no next stage") - def test_runs_with_no_next_stage(self, stage, op): - stage.next = None - stage.run_op(op) - assert_callback_failed(op=op, error=Exception) - - @pytest.mark.it( - "Handles any Exceptions raised by the MQTTPublishOperation and returns them through the op callback" - ) - def test_next_stage_raises_exception(self, stage, op): - stage.next.run_op.side_effect = Exception - stage.run_op(op) - assert_callback_failed(op=op, error=Exception) - - @pytest.mark.it("Allows any BaseExceptions raised by the MQTTPublishOperation to propagate") - def test_next_stage_raises_base_exception(self, stage, op): - stage.next.run_op.side_effect = UnhandledException - with pytest.raises(UnhandledException): - stage.run_op(op) - - @pytest.mark.it( - "Returns op.error as the MQTTPublishOperation error in the op callback if the MQTTPublishOperation returned an error in its operation callback" - ) - def test_publish_op_returns_failure(self, stage, op): - error = Exception() - - def next_stage_run_op(self, op): - op.error = error - op.callback(op) - - stage.next.run_op = functools.partial(next_stage_run_op, (stage.next,)) - stage.run_op(op) - assert_callback_failed(op=op, error=error) - - @pytest.mark.it( - "Returns op.error=None in the operation callback if the MQTTPublishOperation returned op.error=None in its operation callback" - ) - def test_publish_op_returns_success(self, stage, op): - def next_stage_run_op(self, op): - op.callback(op) - - stage.next.run_op = functools.partial(next_stage_run_op, (stage.next,)) - stage.run_op(op) - assert_callback_succeeded(op=op) + assert stage.next.run_op.call_count == 1 + worker_op = stage.next.run_op.call_args[0][0] + assert isinstance(worker_op, pipeline_ops_mqtt.MQTTPublishOperation) + assert ( + worker_op.topic + == "$iothub/twin/{method}{resource_location}?$rid={request_id}".format( + method=fake_method, + resource_location=fake_resource_location, + request_id=fake_request_id, + ) + ) + assert worker_op.payload == fake_request_body @pytest.fixture @@ -1121,9 +1186,9 @@ def input_message_event(): @pytest.mark.describe( - "IoTHubMQTTConverterStage - .handle_pipeline_event() -- called with input message topic" + "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- called with input message topic" ) -class TestIoTHubMQTTConverterHandlePipelineEventInputMessages(object): +class TestIoTHubMQTTConverterHandlePipelineEventInputMessages(IoTHubMQTTTranslationStageTestBase): @pytest.mark.it( "Converts mqtt message with topic devices/device_id/modules/module_id/inputs/input_name/ to input event" ) @@ -1187,9 +1252,9 @@ def method_request_event(): @pytest.mark.describe( - "IoTHubMQTTConverterStage - .handle_pipeline_event() -- called with method request topic" + "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- called with method request topic" ) -class TestIoTHubMQTTConverterHandlePipelineEventMethodRequets(object): +class TestIoTHubMQTTConverterHandlePipelineEventMethodRequets(IoTHubMQTTTranslationStageTestBase): @pytest.mark.it( "Converts mqtt messages with topic $iothub/methods/POST/{method name}/?$rid={request id} to method request events" ) @@ -1241,7 +1306,7 @@ class TestIoTHubMQTTConverterHandlePipelineEventMethodRequets(object): @pytest.mark.describe( "IotHubMQTTConverter - .handle_pipeline_event() -- called with twin response topic" ) -class TestIotHubMQTTConverterHandlePipelineEventTwinResponse(object): +class TestIotHubMQTTConverterHandlePipelineEventTwinResponse(IoTHubMQTTTranslationStageTestBase): @pytest.fixture def fake_request_id(self): return "__fake_request_id__" @@ -1291,7 +1356,7 @@ class TestIotHubMQTTConverterHandlePipelineEventTwinResponse(object): stage.device_id = fake_device_id @pytest.mark.it( - "Calls .handle_pipeline_event() on the previous stage with an IotResponseEvent, with request_id and status_code as attributes extracted from the topic and the response_body attirbute set to the payload" + "Calls .handle_pipeline_event() on the previous stage with an ResponseEvent, with request_id and status_code as attributes extracted from the topic and the response_body attirbute set to the payload" ) def test_extracts_request_id_status_code_and_payload( self, @@ -1305,19 +1370,21 @@ class TestIotHubMQTTConverterHandlePipelineEventTwinResponse(object): stage.handle_pipeline_event(event=fake_event) assert stage.previous.handle_pipeline_event.call_count == 1 new_event = stage.previous.handle_pipeline_event.call_args[0][0] - assert isinstance(new_event, pipeline_events_base.IotResponseEvent) + assert isinstance(new_event, pipeline_events_base.ResponseEvent) assert new_event.status_code == fake_status_code assert new_event.request_id == fake_request_id assert new_event.response_body == fake_payload - @pytest.mark.it("Calls the unhandled exception handler if there is no previous stage") + @pytest.mark.it( + "Calls the unhandled exception handler with a PipelineError if there is no previous stage" + ) def test_no_previous_stage( self, stage, fixup_stage_for_test, fake_event, unhandled_error_handler ): stage.previous = None stage.handle_pipeline_event(fake_event) assert unhandled_error_handler.call_count == 1 - assert isinstance(unhandled_error_handler.call_args[0][0], NotImplementedError) + assert isinstance(unhandled_error_handler.call_args[0][0], PipelineError) @pytest.mark.it( "Calls the unhandled exception handler if the requet_id is missing from the topic name" @@ -1371,7 +1438,7 @@ class TestIotHubMQTTConverterHandlePipelineEventTwinResponse(object): @pytest.mark.describe( "IotHubMQTTConverter - .handle_pipeline_event() -- called with twin patch topic" ) -class TestIotHubMQTTConverterHandlePipelineEventTwinPatch(object): +class TestIotHubMQTTConverterHandlePipelineEventTwinPatch(IoTHubMQTTTranslationStageTestBase): @pytest.fixture def fake_topic_name(self): return "$iothub/twin/PATCH/properties/desired" @@ -1414,14 +1481,16 @@ class TestIotHubMQTTConverterHandlePipelineEventTwinPatch(object): assert isinstance(new_event, pipeline_events_iothub.TwinDesiredPropertiesPatchEvent) assert new_event.patch == fake_patch - @pytest.mark.it("Calls the unhandled exception handler if there is no previous stage") + @pytest.mark.it( + "Calls the unhandled exception handler with a PipelineError if there is no previous stage" + ) def test_no_previous_stage( self, stage, fixup_stage_for_test, fake_event, unhandled_error_handler ): stage.previous = None stage.handle_pipeline_event(fake_event) assert unhandled_error_handler.call_count == 1 - assert isinstance(unhandled_error_handler.call_args[0][0], NotImplementedError) + assert isinstance(unhandled_error_handler.call_args[0][0], PipelineError) @pytest.mark.it("Calls the unhandled exception handler if the payload is not a Bytes object") def test_payload_not_bytes( diff --git a/azure-iot-device/tests/iothub/test_sync_clients.py b/azure-iot-device/tests/iothub/test_sync_clients.py index 7db03bbf7..a52427044 100644 --- a/azure-iot-device/tests/iothub/test_sync_clients.py +++ b/azure-iot-device/tests/iothub/test_sync_clients.py @@ -12,21 +12,29 @@ import os import io import six from azure.iot.device.iothub import IoTHubDeviceClient, IoTHubModuleClient +from azure.iot.device import exceptions as client_exceptions from azure.iot.device.iothub.pipeline import IoTHubPipeline, constant +from azure.iot.device.iothub.pipeline import exceptions as pipeline_exceptions from azure.iot.device.iothub.models import Message, MethodRequest -from azure.iot.device.iothub.sync_inbox import SyncClientInbox, InboxEmpty +from azure.iot.device.iothub.sync_inbox import SyncClientInbox from azure.iot.device.iothub.auth import IoTEdgeError -import azure.iot.device.iothub.sync_clients as sync_clients - +from azure.iot.device import constant as device_constant logging.basicConfig(level=logging.DEBUG) -# automatically mock the pipeline for all tests in this file. + +# automatically mock the mqtt pipeline for all tests in this file. @pytest.fixture(autouse=True) -def mock_pipeline_init(mocker): +def mock_mqtt_pipeline_init(mocker): return mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipeline") +# automatically mock the http pipeline for all tests in this file. +@pytest.fixture(autouse=True) +def mock_http_pipeline_init(mocker): + return mocker.patch("azure.iot.device.iothub.pipeline.HTTPPipeline") + + ################ # SHARED TESTS # ################ @@ -34,30 +42,42 @@ class SharedClientInstantiationTests(object): @pytest.mark.it( "Stores the IoTHubPipeline from the 'iothub_pipeline' parameter in the '_iothub_pipeline' attribute" ) - def test_iothub_pipeline_attribute(self, client_class, iothub_pipeline): - client = client_class(iothub_pipeline) + def test_iothub_pipeline_attribute(self, client_class, iothub_pipeline, http_pipeline): + client = client_class(iothub_pipeline, http_pipeline) assert client._iothub_pipeline is iothub_pipeline + @pytest.mark.it( + "Stores the HTTPPipeline from the 'http_pipeline' parameter in the '_http_pipeline' attribute" + ) + def test_sets_http_pipeline_attribute(self, client_class, iothub_pipeline, http_pipeline): + client = client_class(iothub_pipeline, http_pipeline) + + assert client._http_pipeline is http_pipeline + @pytest.mark.it("Sets on_connected handler in the IoTHubPipeline") - def test_sets_on_connected_handler_in_pipeline(self, client_class, iothub_pipeline): - client = client_class(iothub_pipeline) + def test_sets_on_connected_handler_in_pipeline( + self, client_class, iothub_pipeline, http_pipeline + ): + client = client_class(iothub_pipeline, http_pipeline) assert client._iothub_pipeline.on_connected is not None assert client._iothub_pipeline.on_connected == client._on_connected @pytest.mark.it("Sets on_disconnected handler in the IoTHubPipeline") - def test_sets_on_disconnected_handler_in_pipeline(self, client_class, iothub_pipeline): - client = client_class(iothub_pipeline) + def test_sets_on_disconnected_handler_in_pipeline( + self, client_class, iothub_pipeline, http_pipeline + ): + client = client_class(iothub_pipeline, http_pipeline) assert client._iothub_pipeline.on_disconnected is not None assert client._iothub_pipeline.on_disconnected == client._on_disconnected @pytest.mark.it("Sets on_method_request_received handler in the IoTHubPipeline") def test_sets_on_method_request_received_handler_in_pipleline( - self, client_class, iothub_pipeline + self, client_class, iothub_pipeline, http_pipeline ): - client = client_class(iothub_pipeline) + client = client_class(iothub_pipeline, http_pipeline) assert client._iothub_pipeline.on_method_request_received is not None assert ( @@ -66,90 +86,176 @@ class SharedClientInstantiationTests(object): ) -class SharedClientCreateFromConnectionStringTests(object): +class SharedClientCreateMethodUserOptionTests(object): + # In these tests we patch the entire 'auth' library instead of specific auth providers in order + # to make them more generic, and applicable across all creation methods. + + @pytest.fixture + def option_test_required_patching(self, mocker): + """Override this fixture in a subclass if unique patching is required""" + pass + @pytest.mark.it( - "Uses the connection string and CA certificate combination to create a SymmetricKeyAuthenticationProvider" + "Sets the 'product_info' user option parameter on the PipelineConfig, if provided" ) - @pytest.mark.parametrize( - "ca_cert", - [ - pytest.param(None, id="No CA certificate"), - pytest.param("some-certificate", id="With CA certificate"), - ], + def test_product_info_option( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + + product_info = "MyProductInfo" + client_create_method(*create_method_args, product_info=product_info) + + # Get configuration object, and ensure it was used for both protocol pipelines + assert mock_mqtt_pipeline_init.call_count == 1 + config = mock_mqtt_pipeline_init.call_args[0][1] + assert config == mock_http_pipeline_init.call_args[0][1] + + assert config.product_info == product_info + + @pytest.mark.it( + "Sets the 'websockets' user option parameter on the PipelineConfig, if provided" ) - def test_auth_provider_creation(self, mocker, client_class, connection_string, ca_cert): + def test_websockets_option( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + + client_create_method(*create_method_args, websockets=True) + + # Get configuration object, and ensure it was used for both protocol pipelines + assert mock_mqtt_pipeline_init.call_count == 1 + config = mock_mqtt_pipeline_init.call_args[0][1] + assert config == mock_http_pipeline_init.call_args[0][1] + + assert config.websockets + + # TODO: Show that input in the wrong format is formatted to the correct one. This test exists + # in the IoTHubPipelineConfig object already, but we do not currently show that this is felt + # from the API level. + @pytest.mark.it("Sets the 'cipher' user option parameter on the PipelineConfig, if provided") + def test_cipher_option( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + cipher = "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256" + client_create_method(*create_method_args, cipher=cipher) + + # Get configuration object, and ensure it was used for both protocol pipelines + assert mock_mqtt_pipeline_init.call_count == 1 + config = mock_mqtt_pipeline_init.call_args[0][1] + assert config == mock_http_pipeline_init.call_args[0][1] + + assert config.cipher == cipher + + @pytest.mark.it( + "Sets the 'server_verification_cert' user option parameter on the AuthenticationProvider, if provided" + ) + def test_server_verification_cert_option( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + server_verification_cert = "fake_server_verification_cert" + client_create_method(*create_method_args, server_verification_cert=server_verification_cert) + + # Get auth provider object, and ensure it was used for both protocol pipelines + auth = mock_mqtt_pipeline_init.call_args[0][0] + assert auth == mock_http_pipeline_init.call_args[0][0] + + assert auth.server_verification_cert == server_verification_cert + + @pytest.mark.it("Raises a TypeError if an invalid user option parameter is provided") + def test_invalid_option( + self, option_test_required_patching, client_create_method, create_method_args + ): + with pytest.raises(TypeError): + client_create_method(*create_method_args, invalid_option="some_value") + + @pytest.mark.it("Sets default user options if none are provided") + def test_default_options( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + client_create_method(*create_method_args) + + # Get configuration object, and ensure it was used for both protocol pipelines + assert mock_mqtt_pipeline_init.call_count == 1 + config = mock_mqtt_pipeline_init.call_args[0][1] + assert config == mock_http_pipeline_init.call_args[0][1] + + # Get auth provider object, and ensure it was used for both protocol pipelines + auth = mock_mqtt_pipeline_init.call_args[0][0] + assert auth == mock_http_pipeline_init.call_args[0][0] + + assert config.product_info == "" + assert not config.websockets + assert not config.cipher + assert auth.server_verification_cert is None + + +class SharedClientCreateFromConnectionStringTests(object): + @pytest.mark.it("Uses the connection string to create a SymmetricKeyAuthenticationProvider") + def test_auth_provider_creation(self, mocker, client_class, connection_string): mock_auth_parse = mocker.patch( "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" ).parse - args = (connection_string,) - kwargs = {} - if ca_cert: - kwargs["ca_cert"] = ca_cert - client_class.create_from_connection_string(*args, **kwargs) + client_class.create_from_connection_string(connection_string) assert mock_auth_parse.call_count == 1 assert mock_auth_parse.call_args == mocker.call(connection_string) - assert mock_auth_parse.return_value.ca_cert is ca_cert @pytest.mark.it("Uses the SymmetricKeyAuthenticationProvider to create an IoTHubPipeline") - @pytest.mark.parametrize( - "ca_cert", - [ - pytest.param(None, id="No CA certificate"), - pytest.param("some-certificate", id="With CA certificate"), - ], - ) def test_pipeline_creation( - self, mocker, client_class, connection_string, ca_cert, mock_pipeline_init + self, mocker, client_class, connection_string, mock_mqtt_pipeline_init ): mock_auth = mocker.patch( "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" ).parse.return_value - args = (connection_string,) - kwargs = {} - if ca_cert: - kwargs["ca_cert"] = ca_cert - client_class.create_from_connection_string(*args, **kwargs) + mock_config_init = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipelineConfig") - assert mock_pipeline_init.call_count == 1 - assert mock_pipeline_init.call_args == mocker.call(mock_auth) + client_class.create_from_connection_string(connection_string) + + assert mock_mqtt_pipeline_init.call_count == 1 + assert mock_mqtt_pipeline_init.call_args == mocker.call( + mock_auth, mock_config_init.return_value + ) @pytest.mark.it("Uses the IoTHubPipeline to instantiate the client") - @pytest.mark.parametrize( - "ca_cert", - [ - pytest.param(None, id="No CA certificate"), - pytest.param("some-certificate", id="With CA certificate"), - ], - ) - def test_client_instantiation(self, mocker, client_class, connection_string, ca_cert): + def test_client_instantiation(self, mocker, client_class, connection_string): mock_pipeline = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipeline").return_value + mock_pipeline_http = mocker.patch( + "azure.iot.device.iothub.pipeline.HTTPPipeline" + ).return_value spy_init = mocker.spy(client_class, "__init__") - args = (connection_string,) - kwargs = {} - if ca_cert: - kwargs["ca_cert"] = ca_cert - client_class.create_from_connection_string(*args, **kwargs) - + client_class.create_from_connection_string(connection_string) assert spy_init.call_count == 1 - assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline) + assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline, mock_pipeline_http) @pytest.mark.it("Returns the instantiated client") - @pytest.mark.parametrize( - "ca_cert", - [ - pytest.param(None, id="No CA certificate"), - pytest.param("some-certificate", id="With CA certificate"), - ], - ) - def test_returns_client(self, client_class, connection_string, ca_cert): - args = (connection_string,) - kwargs = {} - if ca_cert: - kwargs["ca_cert"] = ca_cert - client = client_class.create_from_connection_string(*args, **kwargs) + def test_returns_client(self, client_class, connection_string): + client = client_class.create_from_connection_string(connection_string) assert isinstance(client, client_class) @@ -173,63 +279,6 @@ class SharedClientCreateFromConnectionStringTests(object): client_class.create_from_connection_string(bad_cs) -class SharedClientCreateFromSharedAccessSignature(object): - @pytest.mark.it("Uses the SAS token to create a SharedAccessSignatureAuthenticationProvider") - def test_auth_provider_creation(self, mocker, client_class, sas_token_string): - mock_auth_parse = mocker.patch( - "azure.iot.device.iothub.auth.SharedAccessSignatureAuthenticationProvider" - ).parse - - client_class.create_from_shared_access_signature(sas_token_string) - - assert mock_auth_parse.call_count == 1 - assert mock_auth_parse.call_args == mocker.call(sas_token_string) - - @pytest.mark.it( - "Uses the SharedAccessSignatureAuthenticationProvider to create an IoTHubPipeline" - ) - def test_pipeline_creation(self, mocker, client_class, sas_token_string, mock_pipeline_init): - mock_auth = mocker.patch( - "azure.iot.device.iothub.auth.SharedAccessSignatureAuthenticationProvider" - ).parse.return_value - - client_class.create_from_shared_access_signature(sas_token_string) - - assert mock_pipeline_init.call_count == 1 - assert mock_pipeline_init.call_args == mocker.call(mock_auth) - - @pytest.mark.it("Uses the IoTHubPipeline to instantiate the client") - def test_client_instantiation(self, mocker, client_class, sas_token_string): - mock_pipeline = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipeline").return_value - spy_init = mocker.spy(client_class, "__init__") - - client_class.create_from_shared_access_signature(sas_token_string) - - assert spy_init.call_count == 1 - assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline) - - @pytest.mark.it("Returns the instantiated client") - def test_returns_client(self, mocker, client_class, sas_token_string): - client = client_class.create_from_shared_access_signature(sas_token_string) - assert isinstance(client, client_class) - - # TODO: If auth package was refactored to use SasToken class, tests from that - # class would increase the coverage here. - @pytest.mark.it("Raises ValueError when given an invalid SAS token") - @pytest.mark.parametrize( - "bad_sas", - [ - pytest.param(object(), id="Non-string input"), - pytest.param( - "SharedAccessSignature sr=Invalid&sig=Invalid&se=Invalid", id="Malformed SAS token" - ), - ], - ) - def test_raises_value_error_on_bad_sas_token(self, client_class, bad_sas): - with pytest.raises(ValueError): - client_class.create_from_shared_access_signature(bad_sas) - - class WaitsForEventCompletion(object): def add_event_completion_checks(self, mocker, pipeline_function, args=[], kwargs={}): event_init_mock = mocker.patch.object(threading, "Event") @@ -269,18 +318,57 @@ class SharedClientConnectTests(WaitsForEventCompletion): ) client_manual_cb.connect() - @pytest.mark.it("Raises an error if the `connect` pipeline operation calls back with an error") + @pytest.mark.it( + "Raises a client error if the `connect` pipeline operation calls back with a pipeline error" + ) + @pytest.mark.parametrize( + "pipeline_error,client_error", + [ + pytest.param( + pipeline_exceptions.ConnectionDroppedError, + client_exceptions.ConnectionDroppedError, + id="ConnectionDroppedError->ConnectionDroppedError", + ), + pytest.param( + pipeline_exceptions.ConnectionFailedError, + client_exceptions.ConnectionFailedError, + id="ConnectionFailedError->ConnectionFailedError", + ), + pytest.param( + pipeline_exceptions.UnauthorizedError, + client_exceptions.CredentialError, + id="UnauthorizedError->CredentialError", + ), + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param( + pipeline_exceptions.TlsExchangeAuthError, + client_exceptions.ClientError, + id="TlsExchangeAuthError->ClientError", + ), + pytest.param( + pipeline_exceptions.ProtocolProxyError, + client_exceptions.ClientError, + id="ProtocolProxyError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), + ], + ) def test_raises_error_on_pipeline_op_error( - self, mocker, client_manual_cb, iothub_pipeline_manual_cb, fake_error + self, mocker, client_manual_cb, iothub_pipeline_manual_cb, pipeline_error, client_error ): + my_pipeline_error = pipeline_error() self.add_event_completion_checks( mocker=mocker, pipeline_function=iothub_pipeline_manual_cb.connect, - kwargs={"error": fake_error}, + kwargs={"error": my_pipeline_error}, ) - with pytest.raises(fake_error.__class__) as e_info: + with pytest.raises(client_error) as e_info: client_manual_cb.connect() - assert e_info.value is fake_error + assert e_info.value.__cause__ is my_pipeline_error class SharedClientDisconnectTests(WaitsForEventCompletion): @@ -301,19 +389,31 @@ class SharedClientDisconnectTests(WaitsForEventCompletion): client_manual_cb.disconnect() @pytest.mark.it( - "Raises an error if the `disconnect` pipeline operation calls back with an error" + "Raises a client error if the `disconnect` pipeline operation calls back with a pipeline error" + ) + @pytest.mark.parametrize( + "pipeline_error,client_error", + [ + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), + ], ) def test_raises_error_on_pipeline_op_error( - self, mocker, client_manual_cb, iothub_pipeline_manual_cb, fake_error + self, mocker, client_manual_cb, iothub_pipeline_manual_cb, pipeline_error, client_error ): + my_pipeline_error = pipeline_error() self.add_event_completion_checks( mocker=mocker, pipeline_function=iothub_pipeline_manual_cb.disconnect, - kwargs={"error": fake_error}, + kwargs={"error": my_pipeline_error}, ) - with pytest.raises(fake_error.__class__) as e_info: + with pytest.raises(client_error) as e_info: client_manual_cb.disconnect() - assert e_info.value is fake_error + assert e_info.value.__cause__ is my_pipeline_error class SharedClientDisconnectEventTests(object): @@ -343,19 +443,52 @@ class SharedClientSendD2CMessageTests(WaitsForEventCompletion): client_manual_cb.send_message(message) @pytest.mark.it( - "Raises an error if the `send_message` pipeline operation calls back with an error" + "Raises a client error if the `send_message` pipeline operation calls back with a pipeline error" + ) + @pytest.mark.parametrize( + "pipeline_error,client_error", + [ + pytest.param( + pipeline_exceptions.ConnectionDroppedError, + client_exceptions.ConnectionDroppedError, + id="ConnectionDroppedError->ConnectionDroppedError", + ), + pytest.param( + pipeline_exceptions.ConnectionFailedError, + client_exceptions.ConnectionFailedError, + id="ConnectionFailedError->ConnectionFailedError", + ), + pytest.param( + pipeline_exceptions.UnauthorizedError, + client_exceptions.CredentialError, + id="UnauthorizedError->CredentialError", + ), + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), + ], ) def test_raises_error_on_pipeline_op_error( - self, mocker, client_manual_cb, iothub_pipeline_manual_cb, message, fake_error + self, + mocker, + client_manual_cb, + iothub_pipeline_manual_cb, + message, + pipeline_error, + client_error, ): + my_pipeline_error = pipeline_error() self.add_event_completion_checks( mocker=mocker, pipeline_function=iothub_pipeline_manual_cb.send_message, - kwargs={"error": fake_error}, + kwargs={"error": my_pipeline_error}, ) - with pytest.raises(fake_error.__class__) as e_info: + with pytest.raises(client_error) as e_info: client_manual_cb.send_message(message) - assert e_info.value is fake_error + assert e_info.value.__cause__ is my_pipeline_error @pytest.mark.it( "Wraps 'message' input parameter in a Message object if it is not a Message object" @@ -380,6 +513,42 @@ class SharedClientSendD2CMessageTests(WaitsForEventCompletion): assert isinstance(sent_message, Message) assert sent_message.data == message_input + @pytest.mark.it("Raises error when message data size is greater than 256 KB") + def test_raises_error_when_message_data_greater_than_256(self, client, iothub_pipeline): + data_input = "serpensortia" * 25600 + message = Message(data_input) + with pytest.raises(ValueError) as e_info: + client.send_message(message) + assert "256 KB" in e_info.value.args[0] + assert iothub_pipeline.send_message.call_count == 0 + + @pytest.mark.it("Raises error when message size is greater than 256 KB") + def test_raises_error_when_message_size_greater_than_256(self, client, iothub_pipeline): + data_input = "serpensortia" + message = Message(data_input) + message.custom_properties["spell"] = data_input * 25600 + with pytest.raises(ValueError) as e_info: + client.send_message(message) + assert "256 KB" in e_info.value.args[0] + assert iothub_pipeline.send_message.call_count == 0 + + @pytest.mark.it("Does not raises error when message data size is equal to 256 KB") + def test_raises_error_when_message_data_equal_to_256(self, client, iothub_pipeline): + data_input = "a" * 262095 + message = Message(data_input) + # This check was put as message class may undergo the default content type encoding change + # and the above calculation will change. + # Had to do greater than check for python 2. Ideally should be not equal check + if message.get_size() > device_constant.TELEMETRY_MESSAGE_SIZE_LIMIT: + assert False + + client.send_message(message) + + assert iothub_pipeline.send_message.call_count == 1 + sent_message = iothub_pipeline.send_message.call_args[0][0] + assert isinstance(sent_message, Message) + assert sent_message.data == data_input + class SharedClientReceiveMethodRequestTests(object): @pytest.mark.it("Implicitly enables methods feature if not already enabled") @@ -521,26 +690,24 @@ class SharedClientReceiveMethodRequestTests(object): # did not return until after the delay. @pytest.mark.it( - "Raises InboxEmpty exception after a timeout while blocking, in blocking mode with a specified timeout" + "Returns None after a timeout while blocking, in blocking mode with a specified timeout" ) @pytest.mark.parametrize( "method_name", [pytest.param(None, id="Generic Method"), pytest.param("method_x", id="Named Method")], ) def test_times_out_waiting_for_message_blocking_mode(self, client, method_name): - with pytest.raises(InboxEmpty): - client.receive_method_request(method_name, block=True, timeout=0.01) + result = client.receive_method_request(method_name, block=True, timeout=0.01) + assert result is None - @pytest.mark.it( - "Raises InboxEmpty exception immediately if there are no messages, in nonblocking mode" - ) + @pytest.mark.it("Returns None immediately if there are no messages, in nonblocking mode") @pytest.mark.parametrize( "method_name", [pytest.param(None, id="Generic Method"), pytest.param("method_x", id="Named Method")], ) def test_no_message_in_inbox_nonblocking_mode(self, client, method_name): - with pytest.raises(InboxEmpty): - client.receive_method_request(method_name, block=False) + result = client.receive_method_request(method_name, block=False) + assert result is None class SharedClientSendMethodResponseTests(WaitsForEventCompletion): @@ -563,19 +730,52 @@ class SharedClientSendMethodResponseTests(WaitsForEventCompletion): client_manual_cb.send_method_response(method_response) @pytest.mark.it( - "Raises an error if the `send_method_response` pipeline operation calls back with an error" + "Raises a client error if the `send_method_response` pipeline operation calls back with a pipeline error" + ) + @pytest.mark.parametrize( + "pipeline_error,client_error", + [ + pytest.param( + pipeline_exceptions.ConnectionDroppedError, + client_exceptions.ConnectionDroppedError, + id="ConnectionDroppedError->ConnectionDroppedError", + ), + pytest.param( + pipeline_exceptions.ConnectionFailedError, + client_exceptions.ConnectionFailedError, + id="ConnectionFailedError->ConnectionFailedError", + ), + pytest.param( + pipeline_exceptions.UnauthorizedError, + client_exceptions.CredentialError, + id="UnauthorizedError->CredentialError", + ), + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), + ], ) def test_raises_error_on_pipeline_op_error( - self, mocker, client_manual_cb, iothub_pipeline_manual_cb, method_response, fake_error + self, + mocker, + client_manual_cb, + iothub_pipeline_manual_cb, + method_response, + pipeline_error, + client_error, ): + my_pipeline_error = pipeline_error() self.add_event_completion_checks( mocker=mocker, pipeline_function=iothub_pipeline_manual_cb.send_method_response, - kwargs={"error": fake_error}, + kwargs={"error": my_pipeline_error}, ) - with pytest.raises(fake_error.__class__) as e_info: + with pytest.raises(client_error) as e_info: client_manual_cb.send_method_response(method_response) - assert e_info.value is fake_error + assert e_info.value.__cause__ is my_pipeline_error class SharedClientGetTwinTests(WaitsForEventCompletion): @@ -623,18 +823,47 @@ class SharedClientGetTwinTests(WaitsForEventCompletion): ) client_manual_cb.get_twin() - @pytest.mark.it("Raises an error if the `get_twin` pipeline operation calls back with an error") + @pytest.mark.it( + "Raises a client error if the `get_twin` pipeline operation calls back with a pipeline error" + ) + @pytest.mark.parametrize( + "pipeline_error,client_error", + [ + pytest.param( + pipeline_exceptions.ConnectionDroppedError, + client_exceptions.ConnectionDroppedError, + id="ConnectionDroppedError->ConnectionDroppedError", + ), + pytest.param( + pipeline_exceptions.ConnectionFailedError, + client_exceptions.ConnectionFailedError, + id="ConnectionFailedError->ConnectionFailedError", + ), + pytest.param( + pipeline_exceptions.UnauthorizedError, + client_exceptions.CredentialError, + id="UnauthorizedError->CredentialError", + ), + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), + ], + ) def test_raises_error_on_pipeline_op_error( - self, mocker, client_manual_cb, iothub_pipeline_manual_cb, fake_error + self, mocker, client_manual_cb, iothub_pipeline_manual_cb, pipeline_error, client_error ): + my_pipeline_error = pipeline_error() self.add_event_completion_checks( mocker=mocker, pipeline_function=iothub_pipeline_manual_cb.get_twin, - kwargs={"error": fake_error}, + kwargs={"error": my_pipeline_error}, ) - with pytest.raises(fake_error.__class__) as e_info: + with pytest.raises(client_error) as e_info: client_manual_cb.get_twin() - assert e_info.value is fake_error + assert e_info.value.__cause__ is my_pipeline_error @pytest.mark.it("Returns the twin that the pipeline returned") def test_verifies_twin_returned( @@ -701,19 +930,52 @@ class SharedClientPatchTwinReportedPropertiesTests(WaitsForEventCompletion): client_manual_cb.patch_twin_reported_properties(twin_patch_reported) @pytest.mark.it( - "Raises an error if the `patch_twin_reported_properties` pipeline operation calls back with an error" + "Raises a client error if the `patch_twin_reported_properties` pipeline operation calls back with a pipeline error" + ) + @pytest.mark.parametrize( + "pipeline_error,client_error", + [ + pytest.param( + pipeline_exceptions.ConnectionDroppedError, + client_exceptions.ConnectionDroppedError, + id="ConnectionDroppedError->ConnectionDroppedError", + ), + pytest.param( + pipeline_exceptions.ConnectionFailedError, + client_exceptions.ConnectionFailedError, + id="ConnectionFailedError->ConnectionFailedError", + ), + pytest.param( + pipeline_exceptions.UnauthorizedError, + client_exceptions.CredentialError, + id="UnauthorizedError->CredentialError", + ), + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), + ], ) def test_raises_error_on_pipeline_op_error( - self, mocker, client_manual_cb, iothub_pipeline_manual_cb, twin_patch_reported, fake_error + self, + mocker, + client_manual_cb, + iothub_pipeline_manual_cb, + twin_patch_reported, + pipeline_error, + client_error, ): + my_pipeline_error = pipeline_error() self.add_event_completion_checks( mocker=mocker, pipeline_function=iothub_pipeline_manual_cb.patch_twin_reported_properties, - kwargs={"error": fake_error}, + kwargs={"error": my_pipeline_error}, ) - with pytest.raises(fake_error.__class__) as e_info: + with pytest.raises(client_error) as e_info: client_manual_cb.patch_twin_reported_properties(twin_patch_reported) - assert e_info.value is fake_error + assert e_info.value.__cause__ is my_pipeline_error class SharedClientReceiveTwinDesiredPropertiesPatchTests(object): @@ -802,18 +1064,30 @@ class SharedClientReceiveTwinDesiredPropertiesPatchTests(object): # did not return until after the delay. @pytest.mark.it( - "Raises InboxEmpty exception after a timeout while blocking, in blocking mode with a specified timeout" + "Returns None after a timeout while blocking, in blocking mode with a specified timeout" ) def test_times_out_waiting_for_message_blocking_mode(self, client): - with pytest.raises(InboxEmpty): - client.receive_twin_desired_properties_patch(block=True, timeout=0.01) + result = client.receive_twin_desired_properties_patch(block=True, timeout=0.01) + assert result is None - @pytest.mark.it( - "Raises InboxEmpty exception immediately if there are no patches, in nonblocking mode" - ) + @pytest.mark.it("Returns None immediately if there are no patches, in nonblocking mode") def test_no_message_in_inbox_nonblocking_mode(self, client): - with pytest.raises(InboxEmpty): - client.receive_twin_desired_properties_patch(block=False) + result = client.receive_twin_desired_properties_patch(block=False) + assert result is None + + +class SharedClientPROPERTYConnectedTests(object): + @pytest.mark.it("Cannot be changed") + def test_read_only(self, client): + with pytest.raises(AttributeError): + client.connected = not client.connected + + @pytest.mark.it("Reflects the value of the root stage property of the same name") + def test_reflects_pipeline_property(self, client, iothub_pipeline): + iothub_pipeline.connected = True + assert client.connected + iothub_pipeline.connected = False + assert not client.connected ################ @@ -825,18 +1099,18 @@ class IoTHubDeviceClientTestsConfig(object): return IoTHubDeviceClient @pytest.fixture - def client(self, iothub_pipeline): + def client(self, iothub_pipeline, http_pipeline): """This client automatically resolves callbacks sent to the pipeline. It should be used for the majority of tests. """ - return IoTHubDeviceClient(iothub_pipeline) + return IoTHubDeviceClient(iothub_pipeline, http_pipeline) @pytest.fixture - def client_manual_cb(self, iothub_pipeline_manual_cb): + def client_manual_cb(self, iothub_pipeline_manual_cb, http_pipeline_manual_cb): """This client requires manual triggering of the callbacks sent to the pipeline. It should only be used for tests where manual control fo a callback is required. """ - return IoTHubDeviceClient(iothub_pipeline_manual_cb) + return IoTHubDeviceClient(iothub_pipeline_manual_cb, http_pipeline_manual_cb) @pytest.fixture def connection_string(self, device_connection_string): @@ -855,8 +1129,10 @@ class TestIoTHubDeviceClientInstantiation( IoTHubDeviceClientTestsConfig, SharedClientInstantiationTests ): @pytest.mark.it("Sets on_c2d_message_received handler in the IoTHubPipeline") - def test_sets_on_c2d_message_received_handler_in_pipeline(self, client_class, iothub_pipeline): - client = client_class(iothub_pipeline) + def test_sets_on_c2d_message_received_handler_in_pipeline( + self, client_class, iothub_pipeline, http_pipeline + ): + client = client_class(iothub_pipeline, http_pipeline) assert client._iothub_pipeline.on_c2d_message_received is not None assert ( @@ -864,32 +1140,124 @@ class TestIoTHubDeviceClientInstantiation( == client._inbox_manager.route_c2d_message ) - @pytest.mark.it("Sets the '_edge_pipeline' attribute to None") - def test_edge_pipeline_is_none(self, client_class, iothub_pipeline): - client = client_class(iothub_pipeline) - - assert client._edge_pipeline is None - @pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .create_from_connection_string()") class TestIoTHubDeviceClientCreateFromConnectionString( - IoTHubDeviceClientTestsConfig, SharedClientCreateFromConnectionStringTests + IoTHubDeviceClientTestsConfig, + SharedClientCreateMethodUserOptionTests, + SharedClientCreateFromConnectionStringTests, ): - pass + @pytest.fixture + def client_create_method(self, client_class): + """Provides the specific create method for use in universal tests""" + return client_class.create_from_connection_string + + @pytest.fixture + def create_method_args(self, connection_string): + """Provides the specific create method args for use in universal tests""" + return [connection_string] -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .create_from_shared_access_signature()") -class TestIoTHubDeviceClientCreateFromSharedAccessSignature( - IoTHubDeviceClientTestsConfig, SharedClientCreateFromSharedAccessSignature +@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .create_from_symmetric_key()") +class TestIoTHubDeviceClientCreateFromSymmetricKey( + IoTHubDeviceClientTestsConfig, SharedClientCreateMethodUserOptionTests ): - pass + @pytest.fixture + def client_create_method(self, client_class): + """Provides the specific create method for use in universal tests""" + return client_class.create_from_symmetric_key + + @pytest.fixture + def create_method_args(self, symmetric_key, hostname_fixture, device_id_fixture): + """Provides the specific create method args for use in universal tests""" + return [symmetric_key, hostname_fixture, device_id_fixture] + + @pytest.mark.it("Uses the symmetric key to create a SymmetricKeyAuthenticationProvider") + def test_auth_provider_creation( + self, mocker, client_class, symmetric_key, hostname_fixture, device_id_fixture + ): + mock_auth_init = mocker.patch( + "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" + ) + + client_class.create_from_symmetric_key( + symmetric_key=symmetric_key, hostname=hostname_fixture, device_id=device_id_fixture + ) + + assert mock_auth_init.call_count == 1 + assert mock_auth_init.call_args == mocker.call( + hostname=hostname_fixture, + device_id=device_id_fixture, + module_id=None, + shared_access_key=symmetric_key, + ) + + @pytest.mark.it("Uses the SymmetricKeyAuthenticationProvider to create an IoTHubPipeline") + def test_pipeline_creation( + self, + mocker, + client_class, + symmetric_key, + hostname_fixture, + device_id_fixture, + mock_mqtt_pipeline_init, + ): + mock_auth = mocker.patch( + "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" + ).return_value + + mock_config_init = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipelineConfig") + + client_class.create_from_symmetric_key( + symmetric_key=symmetric_key, hostname=hostname_fixture, device_id=device_id_fixture + ) + + assert mock_mqtt_pipeline_init.call_count == 1 + assert mock_mqtt_pipeline_init.call_args == mocker.call( + mock_auth, mock_config_init.return_value + ) + + @pytest.mark.it("Uses the IoTHubPipeline to instantiate the client") + def test_client_instantiation( + self, mocker, client_class, symmetric_key, hostname_fixture, device_id_fixture + ): + mock_pipeline = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipeline").return_value + mock_pipeline_http = mocker.patch( + "azure.iot.device.iothub.pipeline.HTTPPipeline" + ).return_value + spy_init = mocker.spy(client_class, "__init__") + client_class.create_from_symmetric_key( + symmetric_key=symmetric_key, hostname=hostname_fixture, device_id=device_id_fixture + ) + assert spy_init.call_count == 1 + assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline, mock_pipeline_http) + + @pytest.mark.it("Returns the instantiated client") + def test_returns_client(self, client_class, symmetric_key, hostname_fixture, device_id_fixture): + client = client_class.create_from_symmetric_key( + symmetric_key=symmetric_key, hostname=hostname_fixture, device_id=device_id_fixture + ) + + assert isinstance(client, client_class) @pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .create_from_x509_certificate()") -class TestIoTHubDeviceClientCreateFromX509Certificate(IoTHubDeviceClientTestsConfig): +class TestIoTHubDeviceClientCreateFromX509Certificate( + IoTHubDeviceClientTestsConfig, SharedClientCreateMethodUserOptionTests +): hostname = "durmstranginstitute.farend" device_id = "MySnitch" + @pytest.fixture + def client_create_method(self, client_class): + """Provides the specific create method for use in universal tests""" + return client_class.create_from_x509_certificate + + @pytest.fixture + def create_method_args(self, x509): + """Provides the specific create method args for use in universal tests""" + return [x509, self.hostname, self.device_id] + @pytest.mark.it("Uses the provided arguments to create a X509AuthenticationProvider") def test_auth_provider_creation(self, mocker, client_class, x509): mock_auth_init = mocker.patch("azure.iot.device.iothub.auth.X509AuthenticationProvider") @@ -904,21 +1272,28 @@ class TestIoTHubDeviceClientCreateFromX509Certificate(IoTHubDeviceClientTestsCon ) @pytest.mark.it("Uses the X509AuthenticationProvider to create an IoTHubPipeline") - def test_pipeline_creation(self, mocker, client_class, x509, mock_pipeline_init): + def test_pipeline_creation(self, mocker, client_class, x509, mock_mqtt_pipeline_init): mock_auth = mocker.patch( "azure.iot.device.iothub.auth.X509AuthenticationProvider" ).return_value + mock_config_init = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipelineConfig") + client_class.create_from_x509_certificate( x509=x509, hostname=self.hostname, device_id=self.device_id ) - assert mock_pipeline_init.call_count == 1 - assert mock_pipeline_init.call_args == mocker.call(mock_auth) + assert mock_mqtt_pipeline_init.call_count == 1 + assert mock_mqtt_pipeline_init.call_args == mocker.call( + mock_auth, mock_config_init.return_value + ) @pytest.mark.it("Uses the IoTHubPipeline to instantiate the client") def test_client_instantiation(self, mocker, client_class, x509): mock_pipeline = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipeline").return_value + mock_pipeline_http = mocker.patch( + "azure.iot.device.iothub.pipeline.HTTPPipeline" + ).return_value spy_init = mocker.spy(client_class, "__init__") client_class.create_from_x509_certificate( @@ -926,7 +1301,7 @@ class TestIoTHubDeviceClientCreateFromX509Certificate(IoTHubDeviceClientTestsCon ) assert spy_init.call_count == 1 - assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline) + assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline, mock_pipeline_http) @pytest.mark.it("Returns the instantiated client") def test_returns_client(self, mocker, client_class, x509): @@ -947,7 +1322,7 @@ class TestIoTHubDeviceClientDisconnect(IoTHubDeviceClientTestsConfig, SharedClie pass -@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - EVENT: Disconnect") +@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - OCCURANCE: Disconnect") class TestIoTHubDeviceClientDisconnectEvent( IoTHubDeviceClientTestsConfig, SharedClientDisconnectEventTests ): @@ -1041,18 +1416,16 @@ class TestIoTHubDeviceClientReceiveC2DMessage(IoTHubDeviceClientTestsConfig): # did not return until after the delay. @pytest.mark.it( - "Raises InboxEmpty exception after a timeout while blocking, in blocking mode with a specified timeout" + "Returns None after a timeout while blocking, in blocking mode with a specified timeout" ) def test_times_out_waiting_for_message_blocking_mode(self, client): - with pytest.raises(InboxEmpty): - client.receive_message(block=True, timeout=0.01) + result = client.receive_message(block=True, timeout=0.01) + assert result is None - @pytest.mark.it( - "Raises InboxEmpty exception immediately if there are no messages, in nonblocking mode" - ) + @pytest.mark.it("Returns None immediately if there are no messages, in nonblocking mode") def test_no_message_in_inbox_nonblocking_mode(self, client): - with pytest.raises(InboxEmpty): - client.receive_message(block=False) + result = client.receive_message(block=False) + assert result is None @pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .receive_method_request()") @@ -1088,6 +1461,155 @@ class TestIoTHubDeviceClientReceiveTwinDesiredPropertiesPatch( pass +@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .get_storage_info_for_blob()") +class TestIoTHubDeviceClientGetStorageInfo(WaitsForEventCompletion, IoTHubDeviceClientTestsConfig): + @pytest.mark.it("Begins a 'get_storage_info_for_blob' HTTPPipeline operation") + def test_calls_pipeline_get_storage_info_for_blob(self, mocker, client, http_pipeline): + fake_blob_name = "__fake_blob_name__" + client.get_storage_info_for_blob(fake_blob_name) + assert http_pipeline.get_storage_info_for_blob.call_count == 1 + assert http_pipeline.get_storage_info_for_blob.call_args == mocker.call( + fake_blob_name, callback=mocker.ANY + ) + + @pytest.mark.it( + "Waits for the completion of the 'get_storage_info_for_blob' pipeline operation before returning" + ) + def test_waits_for_pipeline_op_completion( + self, mocker, client_manual_cb, http_pipeline_manual_cb + ): + fake_blob_name = "__fake_blob_name__" + + self.add_event_completion_checks( + mocker=mocker, + pipeline_function=http_pipeline_manual_cb.get_storage_info_for_blob, + kwargs={"storage_info": "__fake_storage_info__"}, + ) + + client_manual_cb.get_storage_info_for_blob(fake_blob_name) + + @pytest.mark.it( + "Raises a client error if the `get_storage_info_for_blob` pipeline operation calls back with a pipeline error" + ) + @pytest.mark.parametrize( + "pipeline_error,client_error", + [ + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), + ], + ) + def test_raises_error_on_pipeline_op_error( + self, mocker, client_manual_cb, http_pipeline_manual_cb, pipeline_error, client_error + ): + fake_blob_name = "__fake_blob_name__" + my_pipeline_error = pipeline_error() + self.add_event_completion_checks( + mocker=mocker, + pipeline_function=http_pipeline_manual_cb.get_storage_info_for_blob, + kwargs={"error": my_pipeline_error}, + ) + with pytest.raises(client_error) as e_info: + client_manual_cb.get_storage_info_for_blob(fake_blob_name) + assert e_info.value.__cause__ is my_pipeline_error + + @pytest.mark.it("Returns a storage_info object upon successful completion") + def test_returns_storage_info(self, mocker, client, http_pipeline): + fake_blob_name = "__fake_blob_name__" + fake_storage_info = "__fake_storage_info__" + received_storage_info = client.get_storage_info_for_blob(fake_blob_name) + assert http_pipeline.get_storage_info_for_blob.call_count == 1 + assert http_pipeline.get_storage_info_for_blob.call_args == mocker.call( + fake_blob_name, callback=mocker.ANY + ) + + assert ( + received_storage_info is fake_storage_info + ) # Note: the return value this is checkign for is defined in client_fixtures.py + + +@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .notify_blob_upload_status()") +class TestIoTHubDeviceClientNotifyBlobUploadStatus( + WaitsForEventCompletion, IoTHubDeviceClientTestsConfig +): + @pytest.mark.it("Begins a 'notify_blob_upload_status' HTTPPipeline operation") + def test_calls_pipeline_notify_blob_upload_status(self, client, http_pipeline): + correlation_id = "__fake_correlation_id__" + is_success = "__fake_is_success__" + status_code = "__fake_status_code__" + status_description = "__fake_status_description__" + client.notify_blob_upload_status( + correlation_id, is_success, status_code, status_description + ) + kwargs = http_pipeline.notify_blob_upload_status.call_args[1] + assert http_pipeline.notify_blob_upload_status.call_count == 1 + assert kwargs["correlation_id"] is correlation_id + assert kwargs["is_success"] is is_success + assert kwargs["status_code"] is status_code + assert kwargs["status_description"] is status_description + + @pytest.mark.it( + "Waits for the completion of the 'notify_blob_upload_status' pipeline operation before returning" + ) + def test_waits_for_pipeline_op_completion( + self, mocker, client_manual_cb, http_pipeline_manual_cb + ): + correlation_id = "__fake_correlation_id__" + is_success = "__fake_is_success__" + status_code = "__fake_status_code__" + status_description = "__fake_status_description__" + self.add_event_completion_checks( + mocker=mocker, pipeline_function=http_pipeline_manual_cb.notify_blob_upload_status + ) + + client_manual_cb.notify_blob_upload_status( + correlation_id, is_success, status_code, status_description + ) + + @pytest.mark.it( + "Raises a client error if the `notify_blob_upload_status` pipeline operation calls back with a pipeline error" + ) + @pytest.mark.parametrize( + "pipeline_error,client_error", + [ + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), + ], + ) + def test_raises_error_on_pipeline_op_error( + self, mocker, client_manual_cb, http_pipeline_manual_cb, pipeline_error, client_error + ): + correlation_id = "__fake_correlation_id__" + is_success = "__fake_is_success__" + status_code = "__fake_status_code__" + status_description = "__fake_status_description__" + my_pipeline_error = pipeline_error() + self.add_event_completion_checks( + mocker=mocker, + pipeline_function=http_pipeline_manual_cb.notify_blob_upload_status, + kwargs={"error": my_pipeline_error}, + ) + with pytest.raises(client_error) as e_info: + client_manual_cb.notify_blob_upload_status( + correlation_id, is_success, status_code, status_description + ) + assert e_info.value.__cause__ is my_pipeline_error + + +@pytest.mark.describe("IoTHubDeviceClient (Synchronous) - PROPERTY .connected") +class TestIoTHubDeviceClientPROPERTYConnected( + IoTHubDeviceClientTestsConfig, SharedClientPROPERTYConnectedTests +): + pass + + ################ # MODULE TESTS # ################ @@ -1097,18 +1619,18 @@ class IoTHubModuleClientTestsConfig(object): return IoTHubModuleClient @pytest.fixture - def client(self, iothub_pipeline): + def client(self, iothub_pipeline, http_pipeline): """This client automatically resolves callbacks sent to the pipeline. It should be used for the majority of tests. """ - return IoTHubModuleClient(iothub_pipeline) + return IoTHubModuleClient(iothub_pipeline, http_pipeline) @pytest.fixture - def client_manual_cb(self, iothub_pipeline_manual_cb): + def client_manual_cb(self, iothub_pipeline_manual_cb, http_pipeline_manual_cb): """This client requires manual triggering of the callbacks sent to the pipeline. It should only be used for tests where manual control fo a callback is required. """ - return IoTHubModuleClient(iothub_pipeline_manual_cb) + return IoTHubModuleClient(iothub_pipeline_manual_cb, http_pipeline_manual_cb) @pytest.fixture def connection_string(self, module_connection_string): @@ -1128,9 +1650,9 @@ class TestIoTHubModuleClientInstantiation( ): @pytest.mark.it("Sets on_input_message_received handler in the IoTHubPipeline") def test_sets_on_input_message_received_handler_in_pipeline( - self, client_class, iothub_pipeline + self, client_class, iothub_pipeline, http_pipeline ): - client = client_class(iothub_pipeline) + client = client_class(iothub_pipeline, http_pipeline) assert client._iothub_pipeline.on_input_message_received is not None assert ( @@ -1138,48 +1660,106 @@ class TestIoTHubModuleClientInstantiation( == client._inbox_manager.route_input_message ) - @pytest.mark.it( - "Stores the EdgePipeline from the optionally-provided 'edge_pipeline' parameter in the '_edge_pipeline' attribute" - ) - def test_sets_edge_pipeline_attribute(self, client_class, iothub_pipeline, edge_pipeline): - client = client_class(iothub_pipeline, edge_pipeline) - - assert client._edge_pipeline is edge_pipeline - - @pytest.mark.it( - "Sets the '_edge_pipeline' attribute to None, if the 'edge_pipeline' parameter is not provided" - ) - def test_edge_pipeline_default_none(self, client_class, iothub_pipeline): - client = client_class(iothub_pipeline) - - assert client._edge_pipeline is None - @pytest.mark.describe("IoTHubModuleClient (Synchronous) - .create_from_connection_string()") class TestIoTHubModuleClientCreateFromConnectionString( - IoTHubModuleClientTestsConfig, SharedClientCreateFromConnectionStringTests + IoTHubModuleClientTestsConfig, + SharedClientCreateMethodUserOptionTests, + SharedClientCreateFromConnectionStringTests, ): - pass + @pytest.fixture + def client_create_method(self, client_class): + """Provides the specific create method for use in universal tests""" + return client_class.create_from_connection_string + + @pytest.fixture + def create_method_args(self, connection_string): + """Provides the specific create method args for use in universal tests""" + return [connection_string] -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - .create_from_shared_access_signature()") -class TestIoTHubModuleClientCreateFromSharedAccessSignature( - IoTHubModuleClientTestsConfig, SharedClientCreateFromSharedAccessSignature +class IoTHubModuleClientClientCreateFromEdgeEnvironmentUserOptionTests( + SharedClientCreateMethodUserOptionTests ): - pass + """This class inherites the user option tests shared by all create method APIs, and overrides + tests in order to accomodate unique requirements for the .create_from_edge_enviornment() method. + + Because .create_from_edge_environment() tests are spread accross multiple test units + (i.e. test classes), these overrides are done in this class, which is then inherited by all + .create_from_edge_environment() test units below. + """ + + @pytest.fixture + def client_create_method(self, client_class): + """Provides the specific create method for use in universal tests""" + return client_class.create_from_edge_environment + + @pytest.fixture + def create_method_args(self): + """Provides the specific create method args for use in universal tests""" + return [] + + @pytest.mark.it( + "Raises a TypeError if the 'server_verification_cert' user option parameter is provided" + ) + def test_server_verification_cert_option( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + """THIS TEST OVERRIDES AN INHERITED TEST""" + + with pytest.raises(TypeError): + client_create_method( + *create_method_args, server_verification_cert="fake_server_verification_cert" + ) + + @pytest.mark.it("Sets default user options if none are provided") + def test_default_options( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + """THIS TEST OVERRIDES AN INHERITED TEST""" + client_create_method(*create_method_args) + + # Get configuration object, and ensure it was used for both protocol pipelines + assert mock_mqtt_pipeline_init.call_count == 1 + config = mock_mqtt_pipeline_init.call_args[0][1] + assert config == mock_http_pipeline_init.call_args[0][1] + + # Get auth provider object, and ensure it was used for both protocol pipelines + auth = mock_mqtt_pipeline_init.call_args[0][0] + assert auth == mock_http_pipeline_init.call_args[0][0] + + assert config.product_info == "" + assert not config.websockets + assert not config.cipher @pytest.mark.describe( "IoTHubModuleClient (Synchronous) - .create_from_edge_environment() -- Edge Container Environment" ) class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnv( - IoTHubModuleClientTestsConfig + IoTHubModuleClientTestsConfig, IoTHubModuleClientClientCreateFromEdgeEnvironmentUserOptionTests ): + @pytest.fixture + def option_test_required_patching(self, mocker, edge_container_environment): + """THIS FIXTURE OVERRIDES AN INHERITED FIXTURE""" + mocker.patch("azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider") + mocker.patch.dict(os.environ, edge_container_environment, clear=True) + @pytest.mark.it( "Uses Edge container environment variables to create an IoTEdgeAuthenticationProvider" ) def test_auth_provider_creation(self, mocker, client_class, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment) + mocker.patch.dict(os.environ, edge_container_environment, clear=True) mock_auth_init = mocker.patch("azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider") client_class.create_from_edge_environment() @@ -1204,7 +1784,7 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnv( # This test verifies that with a hybrid environment, the auth provider will always be # an IoTEdgeAuthenticationProvider, even if local debug variables are present hybrid_environment = merge_dicts(edge_container_environment, edge_local_debug_environment) - mocker.patch.dict(os.environ, hybrid_environment) + mocker.patch.dict(os.environ, hybrid_environment, clear=True) mock_edge_auth_init = mocker.patch( "azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider" ) @@ -1227,33 +1807,39 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnv( ) @pytest.mark.it( - "Uses the IoTEdgeAuthenticationProvider to create an IoTHubPipeline and an EdgePipeline" + "Uses the IoTEdgeAuthenticationProvider to create an IoTHubPipeline and an HTTPPipeline" ) def test_pipeline_creation(self, mocker, client_class, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment) + mocker.patch.dict(os.environ, edge_container_environment, clear=True) mock_auth = mocker.patch( "azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider" ).return_value + mock_config_init = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipelineConfig") mock_iothub_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipeline") - mock_edge_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.EdgePipeline") + mock_http_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.HTTPPipeline") client_class.create_from_edge_environment() assert mock_iothub_pipeline_init.call_count == 1 - assert mock_iothub_pipeline_init.call_args == mocker.call(mock_auth) - assert mock_edge_pipeline_init.call_count == 1 - assert mock_edge_pipeline_init.call_args == mocker.call(mock_auth) + assert mock_iothub_pipeline_init.call_args == mocker.call( + mock_auth, mock_config_init.return_value + ) + assert mock_http_pipeline_init.call_count == 1 + # This asserts without mock_config_init because currently edge isn't implemented. When it is, this should be identical to the line aboe. + assert mock_http_pipeline_init.call_args == mocker.call( + mock_auth, mock_config_init.return_value + ) - @pytest.mark.it("Uses the IoTHubPipeline and the EdgePipeline to instantiate the client") + @pytest.mark.it("Uses the IoTHubPipeline and the HTTPPipeline to instantiate the client") def test_client_instantiation(self, mocker, client_class, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment) + mocker.patch.dict(os.environ, edge_container_environment, clear=True) # Always patch the IoTEdgeAuthenticationProvider to prevent I/O operations mocker.patch("azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider") mock_iothub_pipeline = mocker.patch( "azure.iot.device.iothub.pipeline.IoTHubPipeline" ).return_value - mock_edge_pipeline = mocker.patch( - "azure.iot.device.iothub.pipeline.EdgePipeline" + mock_http_pipeline = mocker.patch( + "azure.iot.device.iothub.pipeline.HTTPPipeline" ).return_value spy_init = mocker.spy(client_class, "__init__") @@ -1261,12 +1847,12 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnv( assert spy_init.call_count == 1 assert spy_init.call_args == mocker.call( - mocker.ANY, mock_iothub_pipeline, edge_pipeline=mock_edge_pipeline + mocker.ANY, mock_iothub_pipeline, mock_http_pipeline ) @pytest.mark.it("Returns the instantiated client") def test_returns_client(self, mocker, client_class, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment) + mocker.patch.dict(os.environ, edge_container_environment, clear=True) # Always patch the IoTEdgeAuthenticationProvider to prevent I/O operations mocker.patch("azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider") @@ -1274,7 +1860,7 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnv( assert isinstance(client, client_class) - @pytest.mark.it("Raises IoTEdgeError if the environment is missing required variables") + @pytest.mark.it("Raises OSError if the environment is missing required variables") @pytest.mark.parametrize( "missing_env_var", [ @@ -1292,35 +1878,48 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnv( ): # Remove a variable from the fixture del edge_container_environment[missing_env_var] - mocker.patch.dict(os.environ, edge_container_environment) + mocker.patch.dict(os.environ, edge_container_environment, clear=True) - with pytest.raises(IoTEdgeError): + with pytest.raises(OSError): client_class.create_from_edge_environment() - @pytest.mark.it("Raises IoTEdgeError if there is an error using the Edge for authentication") + @pytest.mark.it("Raises OSError if there is an error using the Edge for authentication") def test_bad_edge_auth(self, mocker, client_class, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment) + mocker.patch.dict(os.environ, edge_container_environment, clear=True) mock_auth = mocker.patch("azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider") - mock_auth.side_effect = IoTEdgeError + my_edge_error = IoTEdgeError() + mock_auth.side_effect = my_edge_error - with pytest.raises(IoTEdgeError): + with pytest.raises(OSError) as e_info: client_class.create_from_edge_environment() + assert e_info.value.__cause__ is my_edge_error @pytest.mark.describe( "IoTHubModuleClient (Synchronous) - .create_from_edge_environment() -- Edge Local Debug Environment" ) -class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv(IoTHubModuleClientTestsConfig): +class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv( + IoTHubModuleClientTestsConfig, IoTHubModuleClientClientCreateFromEdgeEnvironmentUserOptionTests +): + @pytest.fixture + def option_test_required_patching(self, mocker, edge_local_debug_environment): + """THIS FIXTURE OVERRIDES AN INHERITED FIXTURE""" + mocker.patch("azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider") + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) + mocker.patch.object(io, "open") + @pytest.fixture def mock_open(self, mocker): return mocker.patch.object(io, "open") @pytest.mark.it( - "Extracts the CA certificate from the file indicated by the EdgeModuleCACertificateFile environment variable" + "Extracts the server verification certificate from the file indicated by the EdgeModuleCACertificateFile environment variable" ) - def test_read_ca_cert(self, mocker, client_class, edge_local_debug_environment, mock_open): + def test_read_server_verification_cert( + self, mocker, client_class, edge_local_debug_environment, mock_open + ): mock_file_handle = mock_open.return_value.__enter__.return_value - mocker.patch.dict(os.environ, edge_local_debug_environment) + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) client_class.create_from_edge_environment() assert mock_open.call_count == 1 assert mock_open.call_args == mocker.call( @@ -1329,13 +1928,13 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv(IoTHubModuleCl assert mock_file_handle.read.call_count == 1 @pytest.mark.it( - "Uses Edge local debug environment variables to create a SymmetricKeyAuthenticationProvider (with CA cert)" + "Uses Edge local debug environment variables to create a SymmetricKeyAuthenticationProvider (with server verification cert)" ) def test_auth_provider_creation( self, mocker, client_class, edge_local_debug_environment, mock_open ): expected_cert = mock_open.return_value.__enter__.return_value.read.return_value - mocker.patch.dict(os.environ, edge_local_debug_environment) + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) mock_auth_parse = mocker.patch( "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" ).parse @@ -1346,7 +1945,7 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv(IoTHubModuleCl assert mock_auth_parse.call_args == mocker.call( edge_local_debug_environment["EdgeHubConnectionString"] ) - assert mock_auth_parse.return_value.ca_cert == expected_cert + assert mock_auth_parse.return_value.server_verification_cert == expected_cert @pytest.mark.it( "Only uses Edge local debug variables if no Edge container variables are present in the environment" @@ -1362,7 +1961,7 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv(IoTHubModuleCl # This test verifies that with a hybrid environment, the auth provider will always be # an IoTEdgeAuthenticationProvider, even if local debug variables are present hybrid_environment = merge_dicts(edge_container_environment, edge_local_debug_environment) - mocker.patch.dict(os.environ, hybrid_environment) + mocker.patch.dict(os.environ, hybrid_environment, clear=True) mock_edge_auth_init = mocker.patch( "azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider" ) @@ -1385,33 +1984,38 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv(IoTHubModuleCl ) @pytest.mark.it( - "Uses the SymmetricKeyAuthenticationProvider to create an IoTHubPipeline and an EdgePipeline" + "Uses the SymmetricKeyAuthenticationProvider to create an IoTHubPipeline and an HTTPPipeline" ) def test_pipeline_creation(self, mocker, client_class, edge_local_debug_environment, mock_open): - mocker.patch.dict(os.environ, edge_local_debug_environment) + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) mock_auth = mocker.patch( "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" ).parse.return_value + mock_config_init = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipelineConfig") mock_iothub_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipeline") - mock_edge_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.EdgePipeline") + mock_http_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.HTTPPipeline") client_class.create_from_edge_environment() assert mock_iothub_pipeline_init.call_count == 1 - assert mock_iothub_pipeline_init.call_args == mocker.call(mock_auth) - assert mock_edge_pipeline_init.call_count == 1 - assert mock_iothub_pipeline_init.call_args == mocker.call(mock_auth) + assert mock_iothub_pipeline_init.call_args == mocker.call( + mock_auth, mock_config_init.return_value + ) + assert mock_http_pipeline_init.call_count == 1 + assert mock_http_pipeline_init.call_args == mocker.call( + mock_auth, mock_config_init.return_value + ) - @pytest.mark.it("Uses the IoTHubPipeline and the EdgePipeline to instantiate the client") + @pytest.mark.it("Uses the IoTHubPipeline and the HTTPPipeline to instantiate the client") def test_client_instantiation( self, mocker, client_class, edge_local_debug_environment, mock_open ): - mocker.patch.dict(os.environ, edge_local_debug_environment) + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) mock_iothub_pipeline = mocker.patch( "azure.iot.device.iothub.pipeline.IoTHubPipeline" ).return_value - mock_edge_pipeline = mocker.patch( - "azure.iot.device.iothub.pipeline.EdgePipeline" + mock_http_pipeline = mocker.patch( + "azure.iot.device.iothub.pipeline.HTTPPipeline" ).return_value spy_init = mocker.spy(client_class, "__init__") @@ -1419,18 +2023,18 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv(IoTHubModuleCl assert spy_init.call_count == 1 assert spy_init.call_args == mocker.call( - mocker.ANY, mock_iothub_pipeline, edge_pipeline=mock_edge_pipeline + mocker.ANY, mock_iothub_pipeline, mock_http_pipeline ) @pytest.mark.it("Returns the instantiated client") def test_returns_client(self, mocker, client_class, edge_local_debug_environment, mock_open): - mocker.patch.dict(os.environ, edge_local_debug_environment) + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) client = client_class.create_from_edge_environment() assert isinstance(client, client_class) - @pytest.mark.it("Raises IoTEdgeError if the environment is missing required variables") + @pytest.mark.it("Raises OSError if the environment is missing required variables") @pytest.mark.parametrize( "missing_env_var", ["EdgeHubConnectionString", "EdgeModuleCACertificateFile"] ) @@ -1439,9 +2043,9 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv(IoTHubModuleCl ): # Remove a variable from the fixture del edge_local_debug_environment[missing_env_var] - mocker.patch.dict(os.environ, edge_local_debug_environment) + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - with pytest.raises(IoTEdgeError): + with pytest.raises(OSError): client_class.create_from_edge_environment() # TODO: If auth package was refactored to use ConnectionString class, tests from that @@ -1465,7 +2069,7 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv(IoTHubModuleCl self, mocker, client_class, edge_local_debug_environment, bad_cs, mock_open ): edge_local_debug_environment["EdgeHubConnectionString"] = bad_cs - mocker.patch.dict(os.environ, edge_local_debug_environment) + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) with pytest.raises(ValueError): client_class.create_from_edge_environment() @@ -1480,10 +2084,12 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv(IoTHubModuleCl except NameError: FileNotFoundError = IOError - mocker.patch.dict(os.environ, edge_local_debug_environment) - mock_open.side_effect = FileNotFoundError - with pytest.raises(ValueError): + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) + my_fnf_error = FileNotFoundError() + mock_open.side_effect = my_fnf_error + with pytest.raises(ValueError) as e_info: client_class.create_from_edge_environment() + assert e_info.value.__cause__ is my_fnf_error @pytest.mark.it( "Raises ValueError if the file referenced by the filepath in the EdgeModuleCACertificateFile environment variable cannot be opened" @@ -1491,21 +2097,34 @@ class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv(IoTHubModuleCl def test_bad_file_io(self, mocker, client_class, edge_local_debug_environment, mock_open): # Raise a different error in Python 2 vs 3 if six.PY2: - error = IOError + error = IOError() else: - error = OSError - mocker.patch.dict(os.environ, edge_local_debug_environment) + error = OSError() + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) mock_open.side_effect = error - with pytest.raises(ValueError): + with pytest.raises(ValueError) as e_info: client_class.create_from_edge_environment() + assert e_info.value.__cause__ is error @pytest.mark.describe("IoTHubModuleClient (Synchronous) - .create_from_x509_certificate()") -class TestIoTHubModuleClientCreateFromX509Certificate(IoTHubModuleClientTestsConfig): +class TestIoTHubModuleClientCreateFromX509Certificate( + IoTHubModuleClientTestsConfig, SharedClientCreateMethodUserOptionTests +): hostname = "durmstranginstitute.farend" device_id = "MySnitch" module_id = "Charms" + @pytest.fixture + def client_create_method(self, client_class): + """Provides the specific create method for use in universal tests""" + return client_class.create_from_x509_certificate + + @pytest.fixture + def create_method_args(self, x509): + """Provides the specific create method args for use in universal tests""" + return [x509, self.hostname, self.device_id, self.module_id] + @pytest.mark.it("Uses the provided arguments to create a X509AuthenticationProvider") def test_auth_provider_creation(self, mocker, client_class, x509): mock_auth_init = mocker.patch("azure.iot.device.iothub.auth.X509AuthenticationProvider") @@ -1520,21 +2139,29 @@ class TestIoTHubModuleClientCreateFromX509Certificate(IoTHubModuleClientTestsCon ) @pytest.mark.it("Uses the X509AuthenticationProvider to create an IoTHubPipeline") - def test_pipeline_creation(self, mocker, client_class, x509, mock_pipeline_init): + def test_pipeline_creation(self, mocker, client_class, x509, mock_mqtt_pipeline_init): mock_auth = mocker.patch( "azure.iot.device.iothub.auth.X509AuthenticationProvider" ).return_value + mock_config_init = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipelineConfig") + client_class.create_from_x509_certificate( x509=x509, hostname=self.hostname, device_id=self.device_id, module_id=self.module_id ) - assert mock_pipeline_init.call_count == 1 - assert mock_pipeline_init.call_args == mocker.call(mock_auth) + assert mock_mqtt_pipeline_init.call_count == 1 + assert mock_mqtt_pipeline_init.call_args == mocker.call( + mock_auth, mock_config_init.return_value + ) @pytest.mark.it("Uses the IoTHubPipeline to instantiate the client") def test_client_instantiation(self, mocker, client_class, x509): mock_pipeline = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipeline").return_value + mock_pipeline_http = mocker.patch( + "azure.iot.device.iothub.pipeline.HTTPPipeline" + ).return_value + spy_init = mocker.spy(client_class, "__init__") client_class.create_from_x509_certificate( @@ -1542,7 +2169,7 @@ class TestIoTHubModuleClientCreateFromX509Certificate(IoTHubModuleClientTestsCon ) assert spy_init.call_count == 1 - assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline) + assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline, mock_pipeline_http) @pytest.mark.it("Returns the instantiated client") def test_returns_client(self, mocker, client_class, x509): @@ -1563,7 +2190,7 @@ class TestIoTHubModuleClientDisconnect(IoTHubModuleClientTestsConfig, SharedClie pass -@pytest.mark.describe("IoTHubModuleClient (Synchronous) - EVENT: Disconnect") +@pytest.mark.describe("IoTHubModuleClient (Synchronous) - OCCURANCE: Disconnect") class TestIoTHubModuleClientDisconnectEvent( IoTHubModuleClientTestsConfig, SharedClientDisconnectEventTests ): @@ -1600,20 +2227,53 @@ class TestIoTHubModuleClientSendToOutput(IoTHubModuleClientTestsConfig, WaitsFor client_manual_cb.send_message_to_output(message, output_name) @pytest.mark.it( - "Raises an error if the `send_out_event` pipeline operation calls back with an error" + "Raises a client error if the `send_out_event` pipeline operation calls back with a pipeline error" + ) + @pytest.mark.parametrize( + "pipeline_error,client_error", + [ + pytest.param( + pipeline_exceptions.ConnectionDroppedError, + client_exceptions.ConnectionDroppedError, + id="ConnectionDroppedError->ConnectionDroppedError", + ), + pytest.param( + pipeline_exceptions.ConnectionFailedError, + client_exceptions.ConnectionFailedError, + id="ConnectionFailedError->ConnectionFailedError", + ), + pytest.param( + pipeline_exceptions.UnauthorizedError, + client_exceptions.CredentialError, + id="UnauthorizedError->CredentialError", + ), + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), + ], ) def test_raises_error_on_pipeline_op_error( - self, mocker, client_manual_cb, iothub_pipeline_manual_cb, message, fake_error + self, + mocker, + client_manual_cb, + iothub_pipeline_manual_cb, + message, + pipeline_error, + client_error, ): + my_pipeline_error = pipeline_error() self.add_event_completion_checks( mocker=mocker, pipeline_function=iothub_pipeline_manual_cb.send_output_event, - kwargs={"error": fake_error}, + kwargs={"error": my_pipeline_error}, ) output_name = "some_output" - with pytest.raises(fake_error.__class__) as e_info: + with pytest.raises(client_error) as e_info: client_manual_cb.send_message_to_output(message, output_name) - assert e_info.value is fake_error + assert e_info.value.__cause__ is my_pipeline_error @pytest.mark.it( "Wraps 'message' input parameter in Message object if it is not a Message object" @@ -1639,6 +2299,49 @@ class TestIoTHubModuleClientSendToOutput(IoTHubModuleClientTestsConfig, WaitsFor assert isinstance(sent_message, Message) assert sent_message.data == message_input + @pytest.mark.it("Raises error when message data size is greater than 256 KB") + def test_raises_error_when_message_to_output_data_greater_than_256( + self, client, iothub_pipeline + ): + output_name = "some_output" + data_input = "serpensortia" * 256000 + message = Message(data_input) + with pytest.raises(ValueError) as e_info: + client.send_message_to_output(message, output_name) + assert "256 KB" in e_info.value.args[0] + assert iothub_pipeline.send_output_event.call_count == 0 + + @pytest.mark.it("Raises error when message size is greater than 256 KB") + def test_raises_error_when_message_to_output_size_greater_than_256( + self, client, iothub_pipeline + ): + output_name = "some_output" + data_input = "serpensortia" + message = Message(data_input) + message.custom_properties["spell"] = data_input * 256000 + with pytest.raises(ValueError) as e_info: + client.send_message_to_output(message, output_name) + assert "256 KB" in e_info.value.args[0] + assert iothub_pipeline.send_output_event.call_count == 0 + + @pytest.mark.it("Does not raises error when message data size is equal to 256 KB") + def test_raises_error_when_message_to_output_data_equal_to_256(self, client, iothub_pipeline): + output_name = "some_output" + data_input = "a" * 262095 + message = Message(data_input) + # This check was put as message class may undergo the default content type encoding change + # and the above calculation will change. + # Had to do greater than check for python 2. Ideally should be not equal check + if message.get_size() > device_constant.TELEMETRY_MESSAGE_SIZE_LIMIT: + assert False + + client.send_message_to_output(message, output_name) + + assert iothub_pipeline.send_output_event.call_count == 1 + sent_message = iothub_pipeline.send_output_event.call_args[0][0] + assert isinstance(sent_message, Message) + assert sent_message.data == data_input + @pytest.mark.describe("IoTHubModuleClient (Synchronous) - .receive_message_on_input()") class TestIoTHubModuleClientReceiveInputMessage(IoTHubModuleClientTestsConfig): @@ -1737,20 +2440,18 @@ class TestIoTHubModuleClientReceiveInputMessage(IoTHubModuleClientTestsConfig): # did not return until after the delay. @pytest.mark.it( - "Raises InboxEmpty exception after a timeout while blocking, in blocking mode with a specified timeout" + "Returns None after a timeout while blocking, in blocking mode with a specified timeout" ) def test_times_out_waiting_for_message_blocking_mode(self, client): input_name = "some_input" - with pytest.raises(InboxEmpty): - client.receive_message_on_input(input_name, block=True, timeout=0.01) + result = client.receive_message_on_input(input_name, block=True, timeout=0.01) + assert result is None - @pytest.mark.it( - "Raises InboxEmpty exception immediately if there are no messages, in nonblocking mode" - ) + @pytest.mark.it("Returns None immediately if there are no messages, in nonblocking mode") def test_no_message_in_inbox_nonblocking_mode(self, client): input_name = "some_input" - with pytest.raises(InboxEmpty): - client.receive_message_on_input(input_name, block=False) + result = client.receive_message_on_input(input_name, block=False) + assert result is None @pytest.mark.describe("IoTHubModuleClient (Synchronous) - .receive_method_request()") @@ -1786,6 +2487,83 @@ class TestIoTHubModuleClientReceiveTwinDesiredPropertiesPatch( pass +@pytest.mark.describe("IoTHubModuleClient (Synchronous) - .invoke_method()") +class TestIoTHubModuleClientInvokeMethod(WaitsForEventCompletion, IoTHubModuleClientTestsConfig): + @pytest.mark.it("Begins a 'invoke_method' HTTPPipeline operation where the target is a device") + def test_calls_pipeline_invoke_method_for_device(self, client, http_pipeline): + method_params = "__fake_method_params__" + device_id = "__fake_device_id__" + client.invoke_method(method_params, device_id) + assert http_pipeline.invoke_method.call_count == 1 + assert http_pipeline.invoke_method.call_args[0][0] is device_id + assert http_pipeline.invoke_method.call_args[0][1] is method_params + + @pytest.mark.it("Begins a 'invoke_method' HTTPPipeline operation where the target is a module") + def test_calls_pipeline_invoke_method_for_module(self, client, http_pipeline): + method_params = "__fake_method_params__" + device_id = "__fake_device_id__" + module_id = "__fake_module_id__" + client.invoke_method(method_params, device_id, module_id=module_id) + assert http_pipeline.invoke_method.call_count == 1 + assert http_pipeline.invoke_method.call_args[0][0] is device_id + assert http_pipeline.invoke_method.call_args[0][1] is method_params + assert http_pipeline.invoke_method.call_args[1]["module_id"] is module_id + + @pytest.mark.it( + "Waits for the completion of the 'invoke_method' pipeline operation before returning" + ) + def test_waits_for_pipeline_op_completion( + self, mocker, client_manual_cb, http_pipeline_manual_cb + ): + method_params = "__fake_method_params__" + device_id = "__fake_device_id__" + module_id = "__fake_module_id__" + self.add_event_completion_checks( + mocker=mocker, + pipeline_function=http_pipeline_manual_cb.invoke_method, + kwargs={"invoke_method_response": "__fake_invoke_method_response__"}, + ) + + client_manual_cb.invoke_method(method_params, device_id, module_id=module_id) + + @pytest.mark.it( + "Raises a client error if the `invoke_method` pipeline operation calls back with a pipeline error" + ) + @pytest.mark.parametrize( + "pipeline_error,client_error", + [ + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), + ], + ) + def test_raises_error_on_pipeline_op_error( + self, mocker, client_manual_cb, http_pipeline_manual_cb, pipeline_error, client_error + ): + method_params = "__fake_method_params__" + device_id = "__fake_device_id__" + module_id = "__fake_module_id__" + my_pipeline_error = pipeline_error() + self.add_event_completion_checks( + mocker=mocker, + pipeline_function=http_pipeline_manual_cb.invoke_method, + kwargs={"error": my_pipeline_error}, + ) + with pytest.raises(client_error) as e_info: + client_manual_cb.invoke_method(method_params, device_id, module_id=module_id) + assert e_info.value.__cause__ is my_pipeline_error + + +@pytest.mark.describe("IoTHubModule (Synchronous) - PROPERTY .connected") +class TestIoTHubModuleClientPROPERTYConnected( + IoTHubModuleClientTestsConfig, SharedClientPROPERTYConnectedTests +): + pass + + #################### # HELPER FUNCTIONS # #################### diff --git a/azure-iot-device/tests/provisioning/aio/test_async_provisioning_device_client.py b/azure-iot-device/tests/provisioning/aio/test_async_provisioning_device_client.py index 8f7c30bf4..93f3414c2 100644 --- a/azure-iot-device/tests/provisioning/aio/test_async_provisioning_device_client.py +++ b/azure-iot-device/tests/provisioning/aio/test_async_provisioning_device_client.py @@ -5,7 +5,6 @@ # -------------------------------------------------------------------------- import pytest import logging -from azure.iot.device.provisioning.internal.polling_machine import PollingMachine from azure.iot.device.provisioning.aio.async_provisioning_device_client import ( ProvisioningDeviceClient, ) @@ -13,8 +12,12 @@ from azure.iot.device.provisioning.models.registration_result import ( RegistrationResult, RegistrationState, ) +from azure.iot.device.provisioning import security, pipeline from azure.iot.device.common.models.x509 import X509 -from azure.iot.device.provisioning.pipeline import pipeline_ops_provisioning +from azure.iot.device.common import async_adapter +import asyncio +from azure.iot.device.iothub.pipeline import exceptions as pipeline_exceptions +from azure.iot.device import exceptions as client_exceptions logging.basicConfig(level=logging.DEBUG) pytestmark = pytest.mark.asyncio @@ -33,143 +36,438 @@ fake_request_id = "request_1234" fake_device_id = "MyNimbus2000" fake_assigned_hub = "Dumbledore'sArmy" -fake_registration_state = RegistrationState(fake_device_id, fake_assigned_hub, fake_sub_status) + +async def create_completed_future(result=None): + f = asyncio.Future() + f.set_result(result) + return f -def create_success_result(): - return RegistrationResult( - fake_request_id, fake_operation_id, fake_status, fake_registration_state - ) +@pytest.fixture +def registration_result(): + registration_state = RegistrationState(fake_device_id, fake_assigned_hub, fake_sub_status) + return RegistrationResult(fake_operation_id, fake_status, registration_state) -def create_error(): - return RuntimeError("Incoming Failure") - - -def fake_x509(): +@pytest.fixture +def x509(): return X509(fake_x509_cert_file_value, fake_x509_cert_key_file, fake_pass_phrase) -# automatically mock the transport for all tests in this file. @pytest.fixture(autouse=True) -def mock_transport(mocker): - mocker.patch( - "azure.iot.device.common.pipeline.pipeline_stages_mqtt.MQTTTransport", autospec=True +def provisioning_pipeline(mocker): + return mocker.MagicMock(wraps=FakeProvisioningPipeline()) + + +class FakeProvisioningPipeline: + def __init__(self): + self.responses_enabled = {} + + def connect(self, callback): + callback() + + def disconnect(self, callback): + callback() + + def enable_responses(self, callback): + callback() + + def register(self, payload, callback): + callback(result={}) + + +# automatically mock the pipeline for all tests in this file +@pytest.fixture(autouse=True) +def mock_pipeline_init(mocker): + return mocker.patch("azure.iot.device.provisioning.pipeline.ProvisioningPipeline") + + +class SharedClientCreateMethodUserOptionTests(object): + @pytest.mark.it( + "Sets the 'websockets' user option parameter on the PipelineConfig, if provided" ) + async def test_websockets_option( + self, mocker, client_create_method, create_method_args, mock_pipeline_init + ): + client_create_method(*create_method_args, websockets=True) + + # Get configuration object + assert mock_pipeline_init.call_count == 1 + config = mock_pipeline_init.call_args[0][1] + + assert config.websockets + + # TODO: Show that input in the wrong format is formatted to the correct one. This test exists + # in the ProvisioningPipelineConfig object already, but we do not currently show that this is felt + # from the API level. + @pytest.mark.it("Sets the 'cipher' user option parameter on the PipelineConfig, if provided") + async def test_cipher_option( + self, mocker, client_create_method, create_method_args, mock_pipeline_init + ): + + cipher = "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256" + client_create_method(*create_method_args, cipher=cipher) + + # Get configuration object + assert mock_pipeline_init.call_count == 1 + config = mock_pipeline_init.call_args[0][1] + + assert config.cipher == cipher + + @pytest.mark.it("Raises a TypeError if an invalid user option parameter is provided") + async def test_invalid_option( + self, mocker, client_create_method, create_method_args, mock_pipeline_init + ): + with pytest.raises(TypeError): + client_create_method(*create_method_args, invalid_option="some_value") + + @pytest.mark.it("Sets default user options if none are provided") + async def test_default_options( + self, mocker, client_create_method, create_method_args, mock_pipeline_init + ): + client_create_method(*create_method_args) + + # Get configuration object + assert mock_pipeline_init.call_count == 1 + config = mock_pipeline_init.call_args[0][1] + + assert not config.websockets + assert not config.cipher -@pytest.mark.describe("ProvisioningDeviceClient - Init") -class TestClientCreate(object): - xfail_notimplemented = pytest.mark.xfail(raises=NotImplementedError, reason="Unimplemented") - - @pytest.mark.it("Is created from a symmetric key and protocol") - @pytest.mark.parametrize( - "protocol", - [ - pytest.param("mqtt", id="mqtt"), - pytest.param(None, id="optional protocol"), - pytest.param("amqp", id="amqp", marks=xfail_notimplemented), - pytest.param("http", id="http", marks=xfail_notimplemented), - ], +@pytest.mark.describe("ProvisioningDeviceClient - Instantiation") +class TestClientInstantiation(object): + @pytest.mark.it( + "Stores the ProvisioningPipeline from the 'provisioning_pipeline' parameter in the '_provisioning_pipeline' attribute" ) - async def test_create_from_symmetric_key(self, mocker, protocol): + async def test_sets_provisioning_pipeline(self, provisioning_pipeline): + client = ProvisioningDeviceClient(provisioning_pipeline) + + assert client._provisioning_pipeline is provisioning_pipeline + + @pytest.mark.it( + "Instantiates with the initial value of the '_provisioning_payload' attribute set to None" + ) + async def test_payload(self, provisioning_pipeline): + client = ProvisioningDeviceClient(provisioning_pipeline) + + assert client._provisioning_payload is None + + +@pytest.mark.describe("ProvisioningDeviceClient - .create_from_symmetric_key()") +class TestClientCreateFromSymmetricKey(SharedClientCreateMethodUserOptionTests): + @pytest.fixture + async def client_create_method(self): + return ProvisioningDeviceClient.create_from_symmetric_key + + @pytest.fixture + async def create_method_args(self): + return [fake_provisioning_host, fake_registration_id, fake_id_scope, fake_symmetric_key] + + @pytest.mark.it("Creates a SymmetricKeySecurityClient using the given parameters") + async def test_security_client(self, mocker): + spy_sec_client = mocker.spy(security, "SymmetricKeySecurityClient") + + ProvisioningDeviceClient.create_from_symmetric_key( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + symmetric_key=fake_symmetric_key, + ) + + assert spy_sec_client.call_count == 1 + assert spy_sec_client.call_args == mocker.call( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + symmetric_key=fake_symmetric_key, + ) + + @pytest.mark.it( + "Uses the SymmetricKeySecurityClient object and the ProvisioningPipelineConfig object to create a ProvisioningPipeline" + ) + async def test_pipeline(self, mocker, mock_pipeline_init): + # Note that the details of how the pipeline config is set up are covered in the + # SharedClientCreateMethodUserOptionTests + mock_pipeline_config = mocker.patch.object( + pipeline, "ProvisioningPipelineConfig" + ).return_value + mock_sec_client = mocker.patch.object(security, "SymmetricKeySecurityClient").return_value + + ProvisioningDeviceClient.create_from_symmetric_key( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + symmetric_key=fake_symmetric_key, + ) + + assert mock_pipeline_init.call_count == 1 + assert mock_pipeline_init.call_args == mocker.call(mock_sec_client, mock_pipeline_config) + + @pytest.mark.it("Uses the ProvisioningPipeline to instantiate the client") + async def test_client_creation(self, mocker, mock_pipeline_init): + spy_client_init = mocker.spy(ProvisioningDeviceClient, "__init__") + + ProvisioningDeviceClient.create_from_symmetric_key( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + symmetric_key=fake_symmetric_key, + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(mocker.ANY, mock_pipeline_init.return_value) + + @pytest.mark.it("Returns the instantiated client") + async def test_returns_client(self, mocker): client = ProvisioningDeviceClient.create_from_symmetric_key( - fake_provisioning_host, fake_symmetric_key, fake_registration_id, fake_id_scope + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + symmetric_key=fake_symmetric_key, ) assert isinstance(client, ProvisioningDeviceClient) - assert client._provisioning_pipeline is not None - @pytest.mark.it("Is created from a x509 certificate key and protocol") + +@pytest.mark.describe("ProvisioningDeviceClient - .create_from_x509_certificate()") +class TestClientCreateFromX509Certificate(SharedClientCreateMethodUserOptionTests): + @pytest.fixture + def client_create_method(self): + return ProvisioningDeviceClient.create_from_x509_certificate + + @pytest.fixture + def create_method_args(self, x509): + return [fake_provisioning_host, fake_registration_id, fake_id_scope, x509] + + @pytest.mark.it("Creates an X509SecurityClient using the given parameters") + async def test_security_client(self, mocker, x509): + spy_sec_client = mocker.spy(security, "X509SecurityClient") + + ProvisioningDeviceClient.create_from_x509_certificate( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + x509=x509, + ) + + assert spy_sec_client.call_count == 1 + assert spy_sec_client.call_args == mocker.call( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + x509=x509, + ) + + @pytest.mark.it( + "Uses the X509SecurityClient object and the ProvisioningPipelineConfig object to create a ProvisioningPipeline" + ) + async def test_pipeline(self, mocker, mock_pipeline_init, x509): + # Note that the details of how the pipeline config is set up are covered in the + # SharedClientCreateMethodUserOptionTests + mock_pipeline_config = mocker.patch.object( + pipeline, "ProvisioningPipelineConfig" + ).return_value + mock_sec_client = mocker.patch.object(security, "X509SecurityClient").return_value + + ProvisioningDeviceClient.create_from_x509_certificate( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + x509=x509, + ) + + assert mock_pipeline_init.call_count == 1 + assert mock_pipeline_init.call_args == mocker.call(mock_sec_client, mock_pipeline_config) + + @pytest.mark.it("Uses the ProvisioningPipeline to instantiate the client") + async def test_client_creation(self, mocker, mock_pipeline_init, x509): + spy_client_init = mocker.spy(ProvisioningDeviceClient, "__init__") + + ProvisioningDeviceClient.create_from_x509_certificate( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + x509=x509, + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(mocker.ANY, mock_pipeline_init.return_value) + + @pytest.mark.it("Returns the instantiated client") + async def test_returns_client(self, mocker, x509): + client = ProvisioningDeviceClient.create_from_x509_certificate( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + x509=x509, + ) + assert isinstance(client, ProvisioningDeviceClient) + + +@pytest.mark.describe("ProvisioningDeviceClient - .register()") +class TestClientRegister(object): + @pytest.mark.it("Implicitly enables responses from provisioning service if not already enabled") + async def test_enables_provisioning_only_if_not_already_enabled( + self, mocker, provisioning_pipeline, registration_result + ): + # Override callback to pass successful result + def register_complete_success_callback(payload, callback): + callback(result=registration_result) + + mocker.patch.object( + provisioning_pipeline, "register", side_effect=register_complete_success_callback + ) + + provisioning_pipeline.responses_enabled.__getitem__.return_value = False + + client = ProvisioningDeviceClient(provisioning_pipeline) + await client.register() + + assert provisioning_pipeline.enable_responses.call_count == 1 + + provisioning_pipeline.enable_responses.reset_mock() + + provisioning_pipeline.responses_enabled.__getitem__.return_value = True + await client.register() + assert provisioning_pipeline.enable_responses.call_count == 0 + + @pytest.mark.it("Begins a 'register' pipeline operation") + async def test_register_calls_pipeline_register( + self, provisioning_pipeline, mocker, registration_result + ): + def register_complete_success_callback(payload, callback): + callback(result=registration_result) + + mocker.patch.object( + provisioning_pipeline, "register", side_effect=register_complete_success_callback + ) + client = ProvisioningDeviceClient(provisioning_pipeline) + await client.register() + assert provisioning_pipeline.register.call_count == 1 + + @pytest.mark.it( + "Waits for the completion of the 'register' pipeline operation before returning" + ) + async def test_waits_for_pipeline_op_completion( + self, mocker, provisioning_pipeline, registration_result + ): + cb_mock = mocker.patch.object(async_adapter, "AwaitableCallback").return_value + cb_mock.completion.return_value = await create_completed_future(registration_result) + provisioning_pipeline.responses_enabled.__getitem__.return_value = True + + client = ProvisioningDeviceClient(provisioning_pipeline) + client._provisioning_payload = "payload" + await client.register() + + # Assert callback is sent to pipeline + assert provisioning_pipeline.register.call_args[1]["payload"] == "payload" + assert provisioning_pipeline.register.call_args[1]["callback"] is cb_mock + # Assert callback completion is waited upon + assert cb_mock.completion.call_count == 1 + + @pytest.mark.it("Returns the registration result that the pipeline returned") + async def test_verifies_registration_result_returned( + self, mocker, provisioning_pipeline, registration_result + ): + result = registration_result + + def register_complete_success_callback(payload, callback): + callback(result=result) + + mocker.patch.object( + provisioning_pipeline, "register", side_effect=register_complete_success_callback + ) + + client = ProvisioningDeviceClient(provisioning_pipeline) + result_returned = await client.register() + assert result_returned == result + + @pytest.mark.it( + "Raises a client error if the `register` pipeline operation calls back with a pipeline error" + ) @pytest.mark.parametrize( - "protocol", + "pipeline_error,client_error", [ - pytest.param("mqtt", id="mqtt"), - pytest.param(None, id="optional protocol"), - pytest.param("amqp", id="amqp", marks=xfail_notimplemented), - pytest.param("http", id="http", marks=xfail_notimplemented), + pytest.param( + pipeline_exceptions.ConnectionDroppedError, + client_exceptions.ConnectionDroppedError, + id="ConnectionDroppedError->ConnectionDroppedError", + ), + pytest.param( + pipeline_exceptions.ConnectionFailedError, + client_exceptions.ConnectionFailedError, + id="ConnectionFailedError->ConnectionFailedError", + ), + pytest.param( + pipeline_exceptions.UnauthorizedError, + client_exceptions.CredentialError, + id="UnauthorizedError->CredentialError", + ), + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), ], ) - async def test_create_from_x509_cert(self, mocker, protocol): - client = ProvisioningDeviceClient.create_from_x509_certificate( - fake_provisioning_host, fake_registration_id, fake_id_scope, fake_x509() - ) - assert isinstance(client, ProvisioningDeviceClient) - assert client._provisioning_pipeline is not None - - -@pytest.mark.describe("ProvisioningDeviceClient") -class TestClientCallsPollingMachine(object): - @pytest.mark.it( - "Register calls register on polling machine with passed in callback and returns the registration result" - ) - async def test_client_register_success_calls_polling_machine_register_with_callback( - self, mocker, mock_polling_machine + async def test_raises_error_on_pipeline_op_error( + self, mocker, client_error, pipeline_error, provisioning_pipeline ): - # Override callback to pass successful result - def register_complete_success_callback(callback): - callback(result=create_success_result()) + error = pipeline_error() + + def register_complete_failure_callback(payload, callback): + callback(result=None, error=error) mocker.patch.object( - mock_polling_machine, "register", side_effect=register_complete_success_callback + provisioning_pipeline, "register", side_effect=register_complete_failure_callback ) - mqtt_provisioning_pipeline = mocker.MagicMock() - mock_polling_machine_init = mocker.patch( - "azure.iot.device.provisioning.aio.async_provisioning_device_client.PollingMachine" - ) - mock_polling_machine_init.return_value = mock_polling_machine + client = ProvisioningDeviceClient(provisioning_pipeline) - client = ProvisioningDeviceClient(mqtt_provisioning_pipeline) - result = await client.register() - - assert mock_polling_machine.register.call_count == 1 - assert callable(mock_polling_machine.register.call_args[1]["callback"]) - assert result is not None - assert result.registration_state == fake_registration_state - assert result.status == fake_status - assert result.registration_state == fake_registration_state - assert result.registration_state.device_id == fake_device_id - assert result.registration_state.assigned_hub == fake_assigned_hub - - @pytest.mark.it( - "Register calls register on polling machine with passed in callback and raises the error when an error has occured" - ) - async def test_client_register_failure_calls_polling_machine_register_with_callback( - self, mocker, mock_polling_machine - ): - # Override callback to pass successful result - def register_complete_failure_callback(callback): - callback(result=None, error=create_error()) - - mocker.patch.object( - mock_polling_machine, "register", side_effect=register_complete_failure_callback - ) - - mqtt_provisioning_pipeline = mocker.MagicMock() - mock_polling_machine_init = mocker.patch( - "azure.iot.device.provisioning.aio.async_provisioning_device_client.PollingMachine" - ) - mock_polling_machine_init.return_value = mock_polling_machine - - client = ProvisioningDeviceClient(mqtt_provisioning_pipeline) - with pytest.raises(RuntimeError): + with pytest.raises(client_error) as e_info: await client.register() - assert mock_polling_machine.register.call_count == 1 - assert callable(mock_polling_machine.register.call_args[1]["callback"]) + assert e_info.value.__cause__ is error + assert provisioning_pipeline.register.call_count == 1 - @pytest.mark.it("Cancel calls cancel on polling machine with passed in callback") - async def test_client_cancel_calls_polling_machine_cancel_with_callback( - self, mocker, mock_polling_machine - ): - mqtt_provisioning_pipeline = mocker.MagicMock() - mock_polling_machine_init = mocker.patch( - "azure.iot.device.provisioning.aio.async_provisioning_device_client.PollingMachine" - ) - mock_polling_machine_init.return_value = mock_polling_machine - client = ProvisioningDeviceClient(mqtt_provisioning_pipeline) - await client.cancel() +@pytest.mark.describe("ProvisioningDeviceClient - .set_provisioning_payload()") +class TestClientProvisioningPayload(object): + @pytest.mark.it("Sets the payload on the provisioning payload attribute") + @pytest.mark.parametrize( + "payload_input", + [ + pytest.param("Hello Hogwarts", id="String input"), + pytest.param(222, id="Integer input"), + pytest.param(object(), id="Object input"), + pytest.param(None, id="None input"), + pytest.param([1, "str"], id="List input"), + pytest.param({"a": 2}, id="Dictionary input"), + ], + ) + async def test_set_payload(self, mocker, payload_input): + provisioning_pipeline = mocker.MagicMock() - assert mock_polling_machine.cancel.call_count == 1 - assert callable(mock_polling_machine.cancel.call_args[1]["callback"]) + client = ProvisioningDeviceClient(provisioning_pipeline) + client.provisioning_payload = payload_input + assert client._provisioning_payload == payload_input + + @pytest.mark.it("Gets the payload from provisioning payload property") + @pytest.mark.parametrize( + "payload_input", + [ + pytest.param("Hello Hogwarts", id="String input"), + pytest.param(222, id="Integer input"), + pytest.param(object(), id="Object input"), + pytest.param(None, id="None input"), + pytest.param([1, "str"], id="List input"), + pytest.param({"a": 2}, id="Dictionary input"), + ], + ) + async def test_get_payload(self, mocker, payload_input): + provisioning_pipeline = mocker.MagicMock() + + client = ProvisioningDeviceClient(provisioning_pipeline) + client.provisioning_payload = payload_input + assert client.provisioning_payload == payload_input diff --git a/azure-iot-device/tests/provisioning/conftest.py b/azure-iot-device/tests/provisioning/conftest.py index e589ce895..88a5b1924 100644 --- a/azure-iot-device/tests/provisioning/conftest.py +++ b/azure-iot-device/tests/provisioning/conftest.py @@ -10,7 +10,6 @@ from azure.iot.device.provisioning.models.registration_result import ( RegistrationResult, RegistrationState, ) -from azure.iot.device.provisioning.internal.polling_machine import PollingMachine collect_ignore = [] @@ -25,16 +24,3 @@ fake_operation_id = "quidditch_world_cup" fake_request_id = "request_1234" fake_device_id = "MyNimbus2000" fake_assigned_hub = "Dumbledore'sArmy" - - -class FakePollingMachineSuccess(PollingMachine): - def register(self, callback): - callback(result=None, error=None) - - def cancel(self, callback): - callback() - - -@pytest.fixture -def mock_polling_machine(mocker): - return mocker.MagicMock(wraps=FakePollingMachineSuccess(mocker.MagicMock())) diff --git a/azure-iot-device/tests/provisioning/internal/test_polling_machine.py b/azure-iot-device/tests/provisioning/internal/test_polling_machine.py deleted file mode 100644 index 167470899..000000000 --- a/azure-iot-device/tests/provisioning/internal/test_polling_machine.py +++ /dev/null @@ -1,861 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import pytest -import datetime -import logging - -from mock import MagicMock -from azure.iot.device.provisioning.internal.request_response_provider import RequestResponseProvider -from azure.iot.device.provisioning.internal.polling_machine import PollingMachine -from azure.iot.device.provisioning.models.registration_result import RegistrationResult -from azure.iot.device.provisioning.pipeline import constant -import time - -logging.basicConfig(level=logging.DEBUG) - -fake_request_id = "Request1234" -fake_retry_after = "3" -fake_operation_id = "Operation4567" -fake_status = "Flying" -fake_device_id = "MyNimbus2000" -fake_assigned_hub = "Dumbledore'sArmy" -fake_sub_status = "FlyingOnHippogriff" -fake_created_dttm = datetime.datetime(2020, 5, 17) -fake_last_update_dttm = datetime.datetime(2020, 10, 17) -fake_etag = "HighQualityFlyingBroom" -fake_symmetric_key = "Zm9vYmFy" -fake_registration_id = "MyPensieve" -fake_id_scope = "Enchanted0000Ceiling7898" -fake_success_response_topic = "$dps/registrations/res/200/?" -fake_failure_response_topic = "$dps/registrations/res/400/?" -fake_greater_429_response_topic = "$dps/registrations/res/430/?" -fake_assigning_status = "assigning" -fake_assigned_status = "assigned" - - -class SomeRequestResponseProvider(RequestResponseProvider): - def receive_response(self, request_id, status_code, key_values, payload_str): - return super(SomeRequestResponseProvider, self)._receive_response( - request_id=request_id, - status_code=status_code, - key_value_dict=key_values, - response_payload=payload_str, - ) - - -@pytest.fixture -def mock_request_response_provider(mocker): - return mocker.MagicMock(spec=SomeRequestResponseProvider) - - -@pytest.fixture -def mock_polling_machine(mocker, mock_request_response_provider): - state_based_mqtt = MagicMock() - mock_init_request_response_provider = mocker.patch( - "azure.iot.device.provisioning.internal.polling_machine.RequestResponseProvider" - ) - mock_init_request_response_provider.return_value = mock_request_response_provider - mock_polling_machine = PollingMachine(state_based_mqtt) - return mock_polling_machine - - -@pytest.mark.describe("PollingMachine - Register") -class TestRegister(object): - @pytest.mark.it("Calls subscribe on RequestResponseProvider") - def test_register_calls_subscribe_on_request_response_provider(self, mock_polling_machine): - mock_request_response_provider = mock_polling_machine._request_response_provider - mock_polling_machine.register() - - assert mock_request_response_provider.enable_responses.call_count == 1 - assert ( - mock_request_response_provider.enable_responses.call_args[1]["callback"] - == mock_polling_machine._on_subscribe_completed - ) - - @pytest.mark.it("Completes subscription and calls send request on RequestResponseProvider") - def test_on_subscribe_completed_calls_send_register_request_on_request_response_provider( - self, mock_polling_machine, mocker - ): - mock_init_uuid = mocker.patch( - "azure.iot.device.provisioning.internal.polling_machine.uuid.uuid4" - ) - mock_init_uuid.return_value = fake_request_id - mock_init_query_timer = mocker.patch( - "azure.iot.device.provisioning.internal.polling_machine.Timer" - ) - mock_query_timer = mock_init_query_timer.return_value - mocker.patch.object(mock_query_timer, "start") - - mock_polling_machine.state = "initializing" - mock_request_response_provider = mock_polling_machine._request_response_provider - mocker.patch.object(mock_request_response_provider, "send_request") - - mock_polling_machine._on_subscribe_completed() - - assert mock_request_response_provider.send_request.call_count == 1 - assert ( - mock_request_response_provider.send_request.call_args_list[0][1]["request_id"] - == fake_request_id - ) - assert ( - mock_request_response_provider.send_request.call_args_list[0][1]["request_payload"] - == " " - ) - - -@pytest.mark.describe("PollingMachine - Register Response") -class TestRegisterResponse(object): - # Change the timeout so that the test does not hang for more time - constant.DEFAULT_TIMEOUT_INTERVAL = 0.2 - constant.DEFAULT_POLLING_INTERVAL = 0.01 - - @pytest.mark.it("Starts querying when there is a response with 'assigning' registration status") - def test_receive_register_response_assigning_does_query_with_operation_id(self, mocker): - state_based_mqtt = MagicMock() - mock_request_response_provider = SomeRequestResponseProvider(state_based_mqtt) - polling_machine = PollingMachine(state_based_mqtt) - - polling_machine._request_response_provider = mock_request_response_provider - mocker.patch.object(mock_request_response_provider, "enable_responses") - mocker.patch.object(state_based_mqtt, "send_request") - - # to transition into initializing - polling_machine.register(callback=MagicMock()) - - mock_init_uuid = mocker.patch( - "azure.iot.device.provisioning.internal.polling_machine.uuid.uuid4" - ) - mock_init_uuid.return_value = fake_request_id - key_value_dict = {} - key_value_dict["request_id"] = [fake_request_id, " "] - key_value_dict["retry-after"] = [fake_retry_after, " "] - - # to transition into registering - polling_machine._on_subscribe_completed() - - # reset mock to generate different request id for query - mock_init_uuid.reset_mock() - fake_request_id_query = "Request4567" - mock_init_uuid.return_value = fake_request_id_query - - fake_payload_result = ( - '{"operationId":"' + fake_operation_id + '","status":"' + fake_assigning_status + '"}' - ) - - mock_init_polling_timer = mocker.patch( - "azure.iot.device.provisioning.internal.polling_machine.Timer" - ) - - # Complete string pre-fixed by a b is the one that works for all versions of python - # or a encode on a string works for all versions of python - # For only python 3 , bytes(JsonString, "utf-8") can be done - mock_request_response_provider.receive_response( - fake_request_id, "200", key_value_dict, fake_payload_result - ) - - # call polling timer's time up call to simulate polling - time_up_call = mock_init_polling_timer.call_args[0][1] - time_up_call() - - assert state_based_mqtt.send_request.call_count == 2 - assert state_based_mqtt.send_request.call_args_list[0][1]["request_id"] == fake_request_id - assert state_based_mqtt.send_request.call_args_list[0][1]["request_payload"] == " " - - assert ( - state_based_mqtt.send_request.call_args_list[1][1]["request_id"] - == fake_request_id_query - ) - assert ( - state_based_mqtt.send_request.call_args_list[1][1]["operation_id"] == fake_operation_id - ) - assert state_based_mqtt.send_request.call_args_list[1][1]["request_payload"] == " " - - @pytest.mark.it( - "Completes registration process when there is a response with 'assigned' registration status" - ) - def test_receive_register_response_assigned_completes_registration(self, mocker): - state_based_mqtt = MagicMock() - mock_request_response_provider = SomeRequestResponseProvider(state_based_mqtt) - polling_machine = PollingMachine(state_based_mqtt) - - polling_machine._request_response_provider = mock_request_response_provider - mocker.patch.object(mock_request_response_provider, "enable_responses") - mocker.patch.object(state_based_mqtt, "send_request") - mocker.patch.object(mock_request_response_provider, "disconnect") - - # to transition into initializing - mock_callback = MagicMock() - polling_machine.register(callback=mock_callback) - - mock_init_uuid = mocker.patch( - "azure.iot.device.provisioning.internal.polling_machine.uuid.uuid4" - ) - mock_init_uuid.return_value = fake_request_id - key_value_dict = {} - key_value_dict["request_id"] = [fake_request_id, " "] - key_value_dict["retry-after"] = [fake_retry_after, " "] - - # to transition into registering - polling_machine._on_subscribe_completed() - - fake_registration_state = ( - '{"registrationId":"' - + fake_registration_id - + '","assignedHub":"' - + fake_assigned_hub - + '","deviceId":"' - + fake_device_id - + '","substatus":"' - + fake_sub_status - + '"}' - ) - - fake_payload_result = ( - '{"operationId":"' - + fake_operation_id - + '","status":"' - + fake_assigned_status - + '","registrationState":' - + fake_registration_state - + "}" - ) - - mock_request_response_provider.receive_response( - fake_request_id, "200", key_value_dict, fake_payload_result - ) - - polling_machine._on_disconnect_completed_register() - - assert state_based_mqtt.send_request.call_count == 1 - assert state_based_mqtt.send_request.call_args_list[0][1]["request_id"] == fake_request_id - assert state_based_mqtt.send_request.call_args_list[0][1]["request_payload"] == " " - - assert mock_callback.call_count == 1 - assert isinstance(mock_callback.call_args[1]["result"], RegistrationResult) - registration_result = mock_callback.call_args[1]["result"] - - registration_result.request_id == fake_request_id - registration_result.operation_id == fake_operation_id - registration_result.status == fake_assigned_status - registration_result.registration_state.device_id == fake_device_id - registration_result.registration_state.sub_status == fake_sub_status - - @pytest.mark.it( - "Calls callback of register with error when there is a failed response with status code > 300 & status code < 429" - ) - def test_receive_register_response_failure_calls_callback_of_register_error(self, mocker): - state_based_mqtt = MagicMock() - mock_request_response_provider = SomeRequestResponseProvider(state_based_mqtt) - polling_machine = PollingMachine(state_based_mqtt) - polling_machine._request_response_provider = mock_request_response_provider - - mocker.patch.object(mock_request_response_provider, "enable_responses") - mocker.patch.object(state_based_mqtt, "send_request") - mocker.patch.object(mock_request_response_provider, "disconnect") - - # to transition into initializing - mock_callback = MagicMock() - polling_machine.register(callback=mock_callback) - - mock_init_uuid = mocker.patch( - "azure.iot.device.provisioning.internal.polling_machine.uuid.uuid4" - ) - mock_init_uuid.return_value = fake_request_id - key_value_dict = {} - key_value_dict["request_id"] = [fake_request_id, " "] - - # to transition into registering - polling_machine._on_subscribe_completed() - - fake_payload_result = "HelloHogwarts" - mock_request_response_provider.receive_response( - fake_request_id, "400", key_value_dict, fake_payload_result - ) - - polling_machine._on_disconnect_completed_error() - - assert state_based_mqtt.send_request.call_count == 1 - assert state_based_mqtt.send_request.call_args_list[0][1]["request_id"] == fake_request_id - assert state_based_mqtt.send_request.call_args_list[0][1]["request_payload"] == " " - - assert mock_callback.call_count == 1 - assert isinstance(mock_callback.call_args[1]["error"], ValueError) - assert mock_callback.call_args[1]["error"].args[0] == "Incoming message failure" - - @pytest.mark.it( - "Calls callback of register with error when there is a response with unknown registration status" - ) - def test_receive_register_response_some_unknown_status_calls_callback_of_register_error( - self, mocker - ): - state_based_mqtt = MagicMock() - mock_request_response_provider = SomeRequestResponseProvider(state_based_mqtt) - polling_machine = PollingMachine(state_based_mqtt) - polling_machine._request_response_provider = mock_request_response_provider - - mocker.patch.object(mock_request_response_provider, "enable_responses") - mocker.patch.object(state_based_mqtt, "send_request") - mocker.patch.object(mock_request_response_provider, "disconnect") - - # to transition into initializing - mock_callback = MagicMock() - polling_machine.register(callback=mock_callback) - - mock_init_uuid = mocker.patch( - "azure.iot.device.provisioning.internal.polling_machine.uuid.uuid4" - ) - mock_init_uuid.return_value = fake_request_id - key_value_dict = {} - key_value_dict["request_id"] = [fake_request_id, " "] - - # to transition into registering - polling_machine._on_subscribe_completed() - - fake_unknown_status = "disabled" - fake_payload_result = ( - '{"operationId":"' + fake_operation_id + '","status":"' + fake_unknown_status + '"}' - ) - - mock_request_response_provider.receive_response( - fake_request_id, "200", key_value_dict, fake_payload_result - ) - - polling_machine._on_disconnect_completed_error() - - assert mock_callback.call_count == 1 - assert isinstance(mock_callback.call_args[1]["error"], ValueError) - assert ( - mock_callback.call_args[1]["error"].args[0] == "Other types of failure have occurred." - ) - assert mock_callback.call_args[1]["error"].args[1] == fake_payload_result - - @pytest.mark.it("Calls register again when there is a response with status code > 429") - def test_receive_register_response_greater_than_429_does_register_again(self, mocker): - state_based_mqtt = MagicMock() - mock_request_response_provider = SomeRequestResponseProvider(state_based_mqtt) - polling_machine = PollingMachine(state_based_mqtt) - polling_machine._request_response_provider = mock_request_response_provider - - mocker.patch.object(mock_request_response_provider, "enable_responses") - mocker.patch.object(state_based_mqtt, "send_request") - - # to transition into initializing - polling_machine.register(callback=MagicMock()) - - mock_init_uuid = mocker.patch( - "azure.iot.device.provisioning.internal.polling_machine.uuid.uuid4" - ) - mock_init_uuid.return_value = fake_request_id - key_value_dict = {} - key_value_dict["request_id"] = [fake_request_id, " "] - key_value_dict["retry-after"] = [fake_retry_after, " "] - - # to transition into registering - polling_machine._on_subscribe_completed() - - # reset mock to generate different request id for second time register - mock_init_uuid.reset_mock() - fake_request_id_2 = "Request4567" - mock_init_uuid.return_value = fake_request_id_2 - - fake_payload_result = "HelloHogwarts" - - mock_init_polling_timer = mocker.patch( - "azure.iot.device.provisioning.internal.polling_machine.Timer" - ) - - mock_request_response_provider.receive_response( - fake_request_id, "430", key_value_dict, fake_payload_result - ) - - # call polling timer's time up call to simulate polling - time_up_call = mock_init_polling_timer.call_args[0][1] - time_up_call() - - assert state_based_mqtt.send_request.call_count == 2 - assert state_based_mqtt.send_request.call_args_list[0][1]["request_id"] == fake_request_id - assert state_based_mqtt.send_request.call_args_list[0][1]["request_payload"] == " " - - assert state_based_mqtt.send_request.call_args_list[1][1]["request_id"] == fake_request_id_2 - assert state_based_mqtt.send_request.call_args_list[1][1]["request_payload"] == " " - - @pytest.mark.it("Calls callback of register with error when there is a time out") - def test_receive_register_response_after_query_time_passes_calls_callback_with_error( - self, mocker - ): - state_based_mqtt = MagicMock() - mock_request_response_provider = SomeRequestResponseProvider(state_based_mqtt) - polling_machine = PollingMachine(state_based_mqtt) - polling_machine._request_response_provider = mock_request_response_provider - - mocker.patch.object(mock_request_response_provider, "enable_responses") - mocker.patch.object(state_based_mqtt, "send_request") - - # to transition into initializing - mock_callback = MagicMock() - polling_machine.register(callback=mock_callback) - - mock_init_uuid = mocker.patch( - "azure.iot.device.provisioning.internal.polling_machine.uuid.uuid4" - ) - mock_init_uuid.return_value = fake_request_id - - # to transition into registering - polling_machine._on_subscribe_completed() - - # sleep so that it times out query - time.sleep(constant.DEFAULT_TIMEOUT_INTERVAL + 0.2) - - polling_machine._on_disconnect_completed_error() - - assert state_based_mqtt.send_request.call_count == 1 - assert state_based_mqtt.send_request.call_args_list[0][1]["request_id"] == fake_request_id - assert mock_callback.call_count == 1 - assert mock_callback.call_args[1]["error"].args[0] == "Time is up for query timer" - - -@pytest.mark.describe("PollingMachine - Query Response") -class TestQueryResponse(object): - # Change the timeout so that the test does not hang for more time - constant.DEFAULT_TIMEOUT_INTERVAL = 0.2 - constant.DEFAULT_POLLING_INTERVAL = 0.01 - - @pytest.mark.it( - "Does query again when there is a response with 'assigning' registration status" - ) - def test_receive_query_response_assigning_does_query_again_with_same_operation_id(self, mocker): - state_based_mqtt = MagicMock() - mock_request_response_provider = SomeRequestResponseProvider(state_based_mqtt) - polling_machine = PollingMachine(state_based_mqtt) - polling_machine._request_response_provider = mock_request_response_provider - - mocker.patch.object(mock_request_response_provider, "enable_responses") - mocker.patch.object(state_based_mqtt, "send_request") - - # to transition into initializing - polling_machine.register(callback=MagicMock()) - - mock_init_uuid = mocker.patch( - "azure.iot.device.provisioning.internal.polling_machine.uuid.uuid4" - ) - mock_init_uuid.return_value = fake_request_id - key_value_dict = {} - key_value_dict["request_id"] = [fake_request_id, " "] - - # to transition into registering - polling_machine._on_subscribe_completed() - - # reset mock to generate different request id for first query - mock_init_uuid.reset_mock() - fake_request_id_query = "Request4567" - mock_init_uuid.return_value = fake_request_id_query - key_value_dict_2 = {} - key_value_dict_2["request_id"] = [fake_request_id_query, " "] - - # fake_register_topic = fake_success_response_topic + "$rid={}".format(fake_request_id) - fake_register_payload_result = ( - '{"operationId":"' + fake_operation_id + '","status":"' + fake_assigning_status + '"}' - ) - - mock_init_polling_timer = mocker.patch( - "azure.iot.device.provisioning.internal.polling_machine.Timer" - ) - - # Response for register to transition to waiting polling - mock_request_response_provider.receive_response( - fake_request_id, "200", key_value_dict, fake_register_payload_result - ) - - # call polling timer's time up call to simulate polling - time_up_call = mock_init_polling_timer.call_args[0][1] - time_up_call() - - # reset mock to generate different request id for second query - mock_init_uuid.reset_mock() - fake_request_id_query_2 = "Request7890" - mock_init_uuid.return_value = fake_request_id_query_2 - - fake_query_payload_result = ( - '{"operationId":"' + fake_operation_id + '","status":"' + fake_assigning_status + '"}' - ) - - mock_init_polling_timer.reset_mock() - - mock_request_response_provider.receive_response( - fake_request_id_query, "200", key_value_dict_2, fake_query_payload_result - ) - - # call polling timer's time up call to simulate polling - time_up_call = mock_init_polling_timer.call_args[0][1] - time_up_call() - - assert state_based_mqtt.send_request.call_count == 3 - assert state_based_mqtt.send_request.call_args_list[0][1]["request_id"] == fake_request_id - assert state_based_mqtt.send_request.call_args_list[0][1]["request_payload"] == " " - - assert ( - state_based_mqtt.send_request.call_args_list[1][1]["request_id"] - == fake_request_id_query - ) - assert ( - state_based_mqtt.send_request.call_args_list[1][1]["operation_id"] == fake_operation_id - ) - assert state_based_mqtt.send_request.call_args_list[1][1]["request_payload"] == " " - - assert ( - state_based_mqtt.send_request.call_args_list[2][1]["request_id"] - == fake_request_id_query_2 - ) - assert ( - state_based_mqtt.send_request.call_args_list[2][1]["operation_id"] == fake_operation_id - ) - assert state_based_mqtt.send_request.call_args_list[2][1]["request_payload"] == " " - - @pytest.mark.it( - "Completes registration process when there is a query response with 'assigned' registration status" - ) - def test_receive_query_response_assigned_completes_registration(self, mocker): - state_based_mqtt = MagicMock() - mock_request_response_provider = SomeRequestResponseProvider(state_based_mqtt) - polling_machine = PollingMachine(state_based_mqtt) - polling_machine._request_response_provider = mock_request_response_provider - - mocker.patch.object(mock_request_response_provider, "enable_responses") - mocker.patch.object(state_based_mqtt, "send_request") - mocker.patch.object(mock_request_response_provider, "disconnect") - - # to transition into initializing - mock_callback = MagicMock() - polling_machine.register(callback=mock_callback) - - mock_init_uuid = mocker.patch( - "azure.iot.device.provisioning.internal.polling_machine.uuid.uuid4" - ) - mock_init_uuid.return_value = fake_request_id - key_value_dict = {} - key_value_dict["request_id"] = [fake_request_id, " "] - - # to transition into registering - polling_machine._on_subscribe_completed() - - # reset mock to generate different request id for first query - mock_init_uuid.reset_mock() - fake_request_id_query = "Request4567" - mock_init_uuid.return_value = fake_request_id_query - key_value_dict_2 = {} - key_value_dict_2["request_id"] = [fake_request_id_query, " "] - - fake_register_payload_result = ( - '{"operationId":"' + fake_operation_id + '","status":"' + fake_assigning_status + '"}' - ) - - mock_init_polling_timer = mocker.patch( - "azure.iot.device.provisioning.internal.polling_machine.Timer" - ) - - # Response for register to transition to waiting and polling - mock_request_response_provider.receive_response( - fake_request_id, "200", key_value_dict, fake_register_payload_result - ) - - # call polling timer's time up call to simulate polling - time_up_call = mock_init_polling_timer.call_args[0][1] - time_up_call() - - fake_registration_state = ( - '{"registrationId":"' - + fake_registration_id - + '","assignedHub":"' - + fake_assigned_hub - + '","deviceId":"' - + fake_device_id - + '","substatus":"' - + fake_sub_status - + '"}' - ) - - fake_query_payload_result = ( - '{"operationId":"' - + fake_operation_id - + '","status":"' - + fake_assigned_status - + '","registrationState":' - + fake_registration_state - + "}" - ) - - # Response for query - mock_request_response_provider.receive_response( - fake_request_id_query, "200", key_value_dict_2, fake_query_payload_result - ) - - polling_machine._on_disconnect_completed_register() - - assert state_based_mqtt.send_request.call_count == 2 - assert state_based_mqtt.send_request.call_args_list[0][1]["request_id"] == fake_request_id - assert state_based_mqtt.send_request.call_args_list[0][1]["request_payload"] == " " - - assert ( - state_based_mqtt.send_request.call_args_list[1][1]["request_id"] - == fake_request_id_query - ) - assert ( - state_based_mqtt.send_request.call_args_list[1][1]["operation_id"] == fake_operation_id - ) - assert state_based_mqtt.send_request.call_args_list[1][1]["request_payload"] == " " - - assert mock_callback.call_count == 1 - assert isinstance(mock_callback.call_args[1]["result"], RegistrationResult) - - @pytest.mark.it( - "Calls callback of register with error when there is a failed query response with status code > 300 & status code < 429" - ) - def test_receive_query_response_failure_calls_callback_of_register_error(self, mocker): - state_based_mqtt = MagicMock() - mock_request_response_provider = SomeRequestResponseProvider(state_based_mqtt) - polling_machine = PollingMachine(state_based_mqtt) - polling_machine._request_response_provider = mock_request_response_provider - - mocker.patch.object(mock_request_response_provider, "enable_responses") - mocker.patch.object(state_based_mqtt, "send_request") - mocker.patch.object(mock_request_response_provider, "disconnect") - - # to transition into initializing - mock_callback = MagicMock() - polling_machine.register(callback=mock_callback) - - mock_init_uuid = mocker.patch( - "azure.iot.device.provisioning.internal.polling_machine.uuid.uuid4" - ) - mock_init_uuid.return_value = fake_request_id - key_value_dict = {} - key_value_dict["request_id"] = [fake_request_id, " "] - - # to transition into registering - polling_machine._on_subscribe_completed() - - # reset mock to generate different request id for first query - mock_init_uuid.reset_mock() - fake_request_id_query = "Request4567" - mock_init_uuid.return_value = fake_request_id_query - key_value_dict_2 = {} - key_value_dict_2["request_id"] = [fake_request_id_query, " "] - - fake_register_payload_result = ( - '{"operationId":"' + fake_operation_id + '","status":"' + fake_assigning_status + '"}' - ) - - mock_init_polling_timer = mocker.patch( - "azure.iot.device.provisioning.internal.polling_machine.Timer" - ) - - # Response for register to transition to waiting and polling - mock_request_response_provider.receive_response( - fake_request_id, "200", key_value_dict, fake_register_payload_result - ) - - # call polling timer's time up call to simulate polling - time_up_call = mock_init_polling_timer.call_args[0][1] - time_up_call() - - fake_query_payload_result = "HelloHogwarts" - - # Response for query - mock_request_response_provider.receive_response( - fake_request_id_query, "400", key_value_dict_2, fake_query_payload_result - ) - - polling_machine._on_disconnect_completed_error() - - assert state_based_mqtt.send_request.call_count == 2 - assert state_based_mqtt.send_request.call_args_list[0][1]["request_id"] == fake_request_id - assert state_based_mqtt.send_request.call_args_list[0][1]["request_payload"] == " " - - assert ( - state_based_mqtt.send_request.call_args_list[1][1]["request_id"] - == fake_request_id_query - ) - assert ( - state_based_mqtt.send_request.call_args_list[1][1]["operation_id"] == fake_operation_id - ) - assert state_based_mqtt.send_request.call_args_list[1][1]["request_payload"] == " " - - assert mock_callback.call_count == 1 - assert isinstance(mock_callback.call_args[1]["error"], ValueError) - assert mock_callback.call_args[1]["error"].args[0] == "Incoming message failure" - - @pytest.mark.it("Calls query again when there is a response with status code > 429") - def test_receive_query_response_greater_than_429_does_query_again_with_same_operation_id( - self, mocker - ): - state_based_mqtt = MagicMock() - mock_request_response_provider = SomeRequestResponseProvider(state_based_mqtt) - polling_machine = PollingMachine(state_based_mqtt) - polling_machine._request_response_provider = mock_request_response_provider - - mocker.patch.object(mock_request_response_provider, "enable_responses") - mocker.patch.object(state_based_mqtt, "send_request") - mocker.patch.object(mock_request_response_provider, "disconnect") - - # to transition into initializing - polling_machine.register(callback=MagicMock()) - - mock_init_uuid = mocker.patch( - "azure.iot.device.provisioning.internal.polling_machine.uuid.uuid4" - ) - mock_init_uuid.return_value = fake_request_id - key_value_dict = {} - key_value_dict["request_id"] = [fake_request_id, " "] - - # to transition into registering - polling_machine._on_subscribe_completed() - - # reset mock to generate different request id for first query - mock_init_uuid.reset_mock() - fake_request_id_query = "Request4567" - mock_init_uuid.return_value = fake_request_id_query - key_value_dict_2 = {} - key_value_dict_2["request_id"] = [fake_request_id_query, " "] - - # fake_register_topic = fake_success_response_topic + "$rid={}".format(fake_request_id) - fake_register_payload_result = ( - '{"operationId":"' + fake_operation_id + '","status":"' + fake_assigning_status + '"}' - ) - - mock_init_polling_timer = mocker.patch( - "azure.iot.device.provisioning.internal.polling_machine.Timer" - ) - - # Response for register to transition to waiting polling - mock_request_response_provider.receive_response( - fake_request_id, "200", key_value_dict, fake_register_payload_result - ) - - # call polling timer's time up call to simulate polling - time_up_call = mock_init_polling_timer.call_args[0][1] - time_up_call() - - # reset mock to generate different request id for second query - mock_init_uuid.reset_mock() - fake_request_id_query_2 = "Request7890" - mock_init_uuid.return_value = fake_request_id_query_2 - - fake_query_payload_result = "HelloHogwarts" - - mock_init_polling_timer.reset_mock() - - # Response for query - mock_request_response_provider.receive_response( - fake_request_id_query, "430", key_value_dict_2, fake_query_payload_result - ) - - # call polling timer's time up call to simulate polling - time_up_call = mock_init_polling_timer.call_args[0][1] - time_up_call() - - assert state_based_mqtt.send_request.call_count == 3 - assert state_based_mqtt.send_request.call_args_list[0][1]["request_id"] == fake_request_id - assert state_based_mqtt.send_request.call_args_list[0][1]["request_payload"] == " " - - assert ( - state_based_mqtt.send_request.call_args_list[1][1]["request_id"] - == fake_request_id_query - ) - assert ( - state_based_mqtt.send_request.call_args_list[1][1]["operation_id"] == fake_operation_id - ) - assert state_based_mqtt.send_request.call_args_list[1][1]["request_payload"] == " " - - assert ( - state_based_mqtt.send_request.call_args_list[2][1]["request_id"] - == fake_request_id_query_2 - ) - assert ( - state_based_mqtt.send_request.call_args_list[2][1]["operation_id"] == fake_operation_id - ) - assert state_based_mqtt.send_request.call_args_list[2][1]["request_payload"] == " " - - -@pytest.mark.describe("PollingMachine - Cancel") -class TestCancel(object): - # Change the timeout so that the test does not hang for more time - constant.DEFAULT_TIMEOUT_INTERVAL = 0.9 - constant.DEFAULT_POLLING_INTERVAL = 0.09 - - @pytest.mark.it("Calls disconnect on RequestResponseProvider and calls callback") - def test_cancel_disconnects_on_request_response_provider_and_calls_callback( - self, mock_polling_machine - ): - mock_request_response_provider = mock_polling_machine._request_response_provider - - mock_polling_machine.register(callback=MagicMock()) - - mock_cancel_callback = MagicMock() - mock_polling_machine.cancel(mock_cancel_callback) - - mock_request_response_provider.disconnect.assert_called_once_with( - callback=mock_polling_machine._on_disconnect_completed_cancel - ) - - mock_polling_machine._on_disconnect_completed_cancel() - - assert mock_cancel_callback.call_count == 1 - - @pytest.mark.it("Calls disconnect on RequestResponseProvider, clears timers and calls callback") - def test_register_and_cancel_clears_timers_and_disconnects(self, mocker): - state_based_mqtt = MagicMock() - mock_request_response_provider = SomeRequestResponseProvider(state_based_mqtt) - polling_machine = PollingMachine(state_based_mqtt) - polling_machine._request_response_provider = mock_request_response_provider - - mocker.patch.object(mock_request_response_provider, "enable_responses") - mocker.patch.object(state_based_mqtt, "send_request") - mocker.patch.object(mock_request_response_provider, "disconnect") - - # to transition into initializing - polling_machine.register(callback=MagicMock()) - - mock_init_uuid = mocker.patch( - "azure.iot.device.provisioning.internal.polling_machine.uuid.uuid4" - ) - mock_init_uuid.return_value = fake_request_id - key_value_dict = {} - key_value_dict["request_id"] = [fake_request_id, " "] - - # to transition into registering - polling_machine._on_subscribe_completed() - - # reset mock to generate different request id for query - mock_init_uuid.reset_mock() - fake_request_id_query = "Request4567" - mock_init_uuid.return_value = fake_request_id_query - key_value_dict_2 = {} - key_value_dict_2["request_id"] = [fake_request_id_query, " "] - - fake_payload_result = ( - '{"operationId":"' + fake_operation_id + '","status":"' + fake_assigning_status + '"}' - ) - - mock_request_response_provider.receive_response( - fake_request_id, "200", key_value_dict, fake_payload_result - ) - - polling_timer = polling_machine._polling_timer - query_timer = polling_machine._query_timer - poling_timer_cancel = mocker.patch.object(polling_timer, "cancel") - query_timer_cancel = mocker.patch.object(query_timer, "cancel") - - mock_cancel_callback = MagicMock() - polling_machine.cancel(mock_cancel_callback) - - assert poling_timer_cancel.call_count == 1 - assert query_timer_cancel.call_count == 1 - - assert mock_request_response_provider.disconnect.call_count == 1 - polling_machine._on_disconnect_completed_cancel() - - assert mock_cancel_callback.call_count == 1 diff --git a/azure-iot-device/tests/provisioning/internal/test_registration_query_status_result.py b/azure-iot-device/tests/provisioning/internal/test_registration_query_status_result.py deleted file mode 100644 index 4168ed146..000000000 --- a/azure-iot-device/tests/provisioning/internal/test_registration_query_status_result.py +++ /dev/null @@ -1,46 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import pytest -import logging -import datetime -from azure.iot.device.provisioning.internal.registration_query_status_result import ( - RegistrationQueryStatusResult, -) - -logging.basicConfig(level=logging.DEBUG) - -fake_request_id = "Request1234" -fake_retry_after = 6 -fake_operation_id = "Operation4567" -fake_status = "Flying" -fake_device_id = "MyNimbus2000" -fake_assigned_hub = "Dumbledore'sArmy" -fake_sub_status = "FlyingOnHippogriff" -fake_created_dttm = datetime.datetime(2020, 5, 17) -fake_last_update_dttm = datetime.datetime(2020, 10, 17) -fake_etag = "HighQualityFlyingBroom" - - -@pytest.mark.describe("RegistrationQueryStatusResult") -class TestRegistrationQueryStatusResult(object): - @pytest.mark.it("Instantiates correctly") - def test_registration_status_query_result_instantiated_correctly(self): - intermediate_result = RegistrationQueryStatusResult( - fake_request_id, fake_retry_after, fake_operation_id, fake_status - ) - assert intermediate_result.request_id == fake_request_id - assert intermediate_result.retry_after == fake_retry_after - assert intermediate_result.operation_id == fake_operation_id - assert intermediate_result.status == fake_status - - @pytest.mark.it("Has request id that does not have setter") - def test_rid_is_not_settable(self): - registration_result = RegistrationQueryStatusResult( - "RequestId123", "Operation456", "emitted", None - ) - with pytest.raises(AttributeError, match="can't set attribute"): - registration_result.request_id = "MyNimbus2000" diff --git a/azure-iot-device/tests/provisioning/internal/test_request_response_provider.py b/azure-iot-device/tests/provisioning/internal/test_request_response_provider.py deleted file mode 100644 index bd1c6b5b6..000000000 --- a/azure-iot-device/tests/provisioning/internal/test_request_response_provider.py +++ /dev/null @@ -1,160 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import pytest -import logging -import six.moves.urllib as urllib -from mock import MagicMock -from azure.iot.device.provisioning.internal.request_response_provider import RequestResponseProvider - -logging.basicConfig(level=logging.DEBUG) - -fake_request_id = "Request1234" -fake_operation_id = "Operation4567" -fake_request_topic = "$ministryservices/wizardregistrations/$rid={}" -fake_subscribe_topic = "$dps/registrations/res/#" -fake_success_response_topic = "$dps/registrations/res/9999/?$rid={}".format(fake_request_id) -POS_STATUS_CODE_IN_TOPIC = 3 -POS_QUERY_PARAM_PORTION = 2 -POS_URL_PORTION = 1 - - -@pytest.fixture -def request_response_provider(): - provisioning_pipeline = MagicMock() - request_response_provider = RequestResponseProvider(provisioning_pipeline) - request_response_provider.on_response_received = MagicMock() - return request_response_provider - - -@pytest.mark.describe("RequestResponseProvider") -class TestRequestResponseProvider(object): - @pytest.mark.it("connect calls connect on state based provider with given callback") - def test_connect_calls_connect_on_provisioning_pipeline_with_provided_callback( - self, request_response_provider - ): - mock_provisioning_pipeline = request_response_provider._provisioning_pipeline - mock_callback = MagicMock() - request_response_provider.connect(mock_callback) - mock_provisioning_pipeline.connect.assert_called_once_with(callback=mock_callback) - - @pytest.mark.it("connect calls connect on state based provider with defined callback") - def test_connect_calls_connect_on_provisioning_pipeline(self, request_response_provider): - mock_provisioning_pipeline = request_response_provider._provisioning_pipeline - request_response_provider.connect() - mock_provisioning_pipeline.connect.assert_called_once_with( - callback=request_response_provider._on_connection_state_change - ) - - @pytest.mark.it("disconnect calls disconnect on state based provider with given callback") - def test_disconnect_calls_disconnect_on_provisioning_pipeline_with_provided_callback( - self, request_response_provider - ): - mock_provisioning_pipeline = request_response_provider._provisioning_pipeline - mock_callback = MagicMock() - request_response_provider.disconnect(mock_callback) - mock_provisioning_pipeline.disconnect.assert_called_once_with(callback=mock_callback) - - @pytest.mark.it("disconnect calls disconnect on state based provider with defined callback") - def test_disconnect_calls_disconnect_on_provisioning_pipeline(self, request_response_provider): - mock_provisioning_pipeline = request_response_provider._provisioning_pipeline - request_response_provider.disconnect() - mock_provisioning_pipeline.disconnect.assert_called_once_with( - callback=request_response_provider._on_connection_state_change - ) - - @pytest.mark.it("Send request calls send request on pipeline with request") - def test_send_request_calls_publish_on_provisioning_pipeline(self, request_response_provider): - mock_provisioning_pipeline = request_response_provider._provisioning_pipeline - req = "Leviosa" - mock_callback = MagicMock() - request_response_provider.send_request( - request_id=fake_request_id, - request_payload=req, - operation_id=fake_operation_id, - callback_on_response=mock_callback, - ) - assert mock_provisioning_pipeline.send_request.call_count == 1 - print(mock_provisioning_pipeline.send_request.call_args) - assert ( - mock_provisioning_pipeline.send_request.call_args[1]["operation_id"] - == fake_operation_id - ) - assert mock_provisioning_pipeline.send_request.call_args[1]["request_id"] == fake_request_id - - assert mock_provisioning_pipeline.send_request.call_args[1]["request_payload"] == req - - @pytest.mark.it( - "Enable_responses calls enable_responses on pipeline with topic and given callback" - ) - def test_enable_responses_calls_enable_responses_on_provisioning_pipeline_with_provided_callback( - self, request_response_provider - ): - mock_provisioning_pipeline = request_response_provider._provisioning_pipeline - mock_callback = MagicMock() - request_response_provider.enable_responses(mock_callback) - mock_provisioning_pipeline.enable_responses.assert_called_once_with(callback=mock_callback) - - @pytest.mark.it( - "Enable_responses calls enable_responses on pipeline with topic and defined callback" - ) - def test_enable_responses_calls_enable_responses_on_provisioning_pipeline( - self, request_response_provider - ): - mock_provisioning_pipeline = request_response_provider._provisioning_pipeline - request_response_provider.enable_responses() - mock_provisioning_pipeline.enable_responses.assert_called_once_with( - callback=request_response_provider._on_subscribe_completed - ) - - @pytest.mark.it("Unsubscribe calls unsubscribe on pipeline with topic and given callback") - def test_disable_responses_calls_disable_responses_on_provisioning_pipeline_with_provided_callback( - self, request_response_provider - ): - mock_provisioning_pipeline = request_response_provider._provisioning_pipeline - mock_callback = MagicMock() - request_response_provider.disable_responses(mock_callback) - mock_provisioning_pipeline.disable_responses.assert_called_once_with(callback=mock_callback) - - @pytest.mark.it( - "Disable_response calls disable_response on pipeline with topic and defined callback" - ) - def test_disable_responses_calls_disable_responses_on_provisioning_pipeline( - self, request_response_provider - ): - mock_provisioning_pipeline = request_response_provider._provisioning_pipeline - request_response_provider.disable_responses() - mock_provisioning_pipeline.disable_responses.assert_called_once_with( - callback=request_response_provider._on_unsubscribe_completed - ) - - @pytest.mark.it("Receives message and calls callback passed with payload") - def test_on_provider_message_received_receives_response_and_calls_callback( - self, request_response_provider - ): - mock_provisioning_pipeline = request_response_provider._provisioning_pipeline - req = "Leviosa" - - mock_callback = MagicMock() - request_response_provider.send_request( - request_id=fake_request_id, - request_payload=req, - operation_id=fake_operation_id, - callback_on_response=mock_callback, - ) - assigning_status = "assigning" - - payload = '{"operationId":"' + fake_operation_id + '","status":"' + assigning_status + '"}' - - topic_parts = fake_success_response_topic.split("$") - key_value_dict = urllib.parse.parse_qs(topic_parts[POS_QUERY_PARAM_PORTION]) - - mock_payload = payload.encode("utf-8") - mock_provisioning_pipeline.on_message_received( - fake_request_id, "202", key_value_dict, mock_payload - ) - - mock_callback.assert_called_once_with(fake_request_id, "202", key_value_dict, mock_payload) diff --git a/azure-iot-device/tests/provisioning/models/test_registration_result.py b/azure-iot-device/tests/provisioning/models/test_registration_result.py index d929b71f1..24ca3741e 100644 --- a/azure-iot-device/tests/provisioning/models/test_registration_result.py +++ b/azure-iot-device/tests/provisioning/models/test_registration_result.py @@ -11,6 +11,7 @@ from azure.iot.device.provisioning.models.registration_result import ( RegistrationResult, RegistrationState, ) +import json logging.basicConfig(level=logging.DEBUG) @@ -23,16 +24,16 @@ fake_sub_status = "FlyingOnHippogriff" fake_created_dttm = datetime.datetime(2020, 5, 17) fake_last_update_dttm = datetime.datetime(2020, 10, 17) fake_etag = "HighQualityFlyingBroom" +fake_payload = "petrificus totalus" @pytest.mark.describe("RegistrationResult") class TestRegistrationResult(object): @pytest.mark.it("Instantiates correctly") def test_registration_result_instantiated_correctly(self): - fake_registration_state = create_registraion_state() + fake_registration_state = create_registration_state() registration_result = create_registration_result(fake_registration_state) - assert registration_result.request_id == fake_request_id assert registration_result.operation_id == fake_operation_id assert registration_result.status == fake_status assert registration_result.registration_state == fake_registration_state @@ -46,7 +47,7 @@ class TestRegistrationResult(object): @pytest.mark.it("Has a to string representation composed of registration state and status") def test_registration_result_to_string(self): - fake_registration_state = create_registraion_state() + fake_registration_state = create_registration_state() registration_result = create_registration_result(fake_registration_state) string_repr = "\n".join([str(fake_registration_state), fake_status]) @@ -55,7 +56,6 @@ class TestRegistrationResult(object): @pytest.mark.parametrize( "input_setter_code", [ - pytest.param('registration_result.request_id = "RequestId123"', id="Request Id"), pytest.param('registration_result.operation_id = "WhompingWillow"', id="Operation Id"), pytest.param('registration_result.status = "Apparating"', id="Status"), pytest.param( @@ -89,7 +89,7 @@ class TestRegistrationResult(object): ) @pytest.mark.it("Has `RegistrationState` with properties that do not have setter") def test_some_properties_of_state_are_not_settable(self, input_setter_code): - registration_state = create_registraion_state() # noqa: F841 + registration_state = create_registration_state() # noqa: F841 with pytest.raises(AttributeError, match="can't set attribute"): exec(input_setter_code) @@ -97,14 +97,27 @@ class TestRegistrationResult(object): @pytest.mark.it( "Has a to string representation composed of device id, assigned hub and sub status" ) - def test_registration_state_to_string(self): - registration_state = create_registraion_state() + def test_registration_state_to_string_without_payload(self): + registration_state = create_registration_state() + # Serializes the __dict__ of every object instead of the object itself. + # Helpful for all sorts of complex objects. + json_payload = json.dumps(None, default=lambda o: o.__dict__, sort_keys=True) - string_repr = "\n".join([fake_device_id, fake_assigned_hub, fake_sub_status]) + string_repr = "\n".join([fake_device_id, fake_assigned_hub, fake_sub_status, json_payload]) + assert str(registration_state) == string_repr + + @pytest.mark.it( + "Has a to string representation composed of device id, assigned hub, sub status and response payload" + ) + def test_registration_state_to_string_with_payload(self): + registration_state = create_registration_state(fake_payload) + json_payload = json.dumps(fake_payload, default=lambda o: o.__dict__, sort_keys=True) + + string_repr = "\n".join([fake_device_id, fake_assigned_hub, fake_sub_status, json_payload]) assert str(registration_state) == string_repr -def create_registraion_state(): +def create_registration_state(payload=None): return RegistrationState( fake_device_id, fake_assigned_hub, @@ -112,8 +125,9 @@ def create_registraion_state(): fake_created_dttm, fake_last_update_dttm, fake_etag, + payload, ) def create_registration_result(registration_state=None): - return RegistrationResult(fake_request_id, fake_operation_id, fake_status, registration_state) + return RegistrationResult(fake_operation_id, fake_status, registration_state) diff --git a/azure-iot-device/tests/provisioning/pipeline/conftest.py b/azure-iot-device/tests/provisioning/pipeline/conftest.py index 4168a3db8..f0fbc1a79 100644 --- a/azure-iot-device/tests/provisioning/pipeline/conftest.py +++ b/azure-iot-device/tests/provisioning/pipeline/conftest.py @@ -5,11 +5,9 @@ # -------------------------------------------------------------------------- from tests.common.pipeline.fixtures import ( - callback, - fake_exception, - fake_base_exception, - event, fake_pipeline_thread, fake_non_pipeline_thread, unhandled_error_handler, + arbitrary_op, + arbitrary_event, ) diff --git a/azure-iot-device/tests/provisioning/pipeline/helpers.py b/azure-iot-device/tests/provisioning/pipeline/helpers.py index bcb761977..dccc21973 100644 --- a/azure-iot-device/tests/provisioning/pipeline/helpers.py +++ b/azure-iot-device/tests/provisioning/pipeline/helpers.py @@ -3,22 +3,17 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from azure.iot.device.provisioning.pipeline import ( - pipeline_events_provisioning, - pipeline_ops_provisioning, -) +from azure.iot.device.provisioning.pipeline import pipeline_ops_provisioning all_provisioning_ops = [ pipeline_ops_provisioning.SetSymmetricKeySecurityClientOperation, pipeline_ops_provisioning.SetX509SecurityClientOperation, pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation, - pipeline_ops_provisioning.SendRegistrationRequestOperation, - pipeline_ops_provisioning.SendQueryRequestOperation, + pipeline_ops_provisioning.RegisterOperation, + pipeline_ops_provisioning.PollStatusOperation, ] fake_key_values = {} fake_key_values["request_id"] = ["request_1234", " "] fake_key_values["retry-after"] = ["300", " "] fake_key_values["name"] = ["hermione", " "] - -all_provisioning_events = [pipeline_events_provisioning.RegistrationResponseEvent] diff --git a/azure-iot-device/tests/provisioning/pipeline/test_config.py b/azure-iot-device/tests/provisioning/pipeline/test_config.py new file mode 100644 index 000000000..6ceb64814 --- /dev/null +++ b/azure-iot-device/tests/provisioning/pipeline/test_config.py @@ -0,0 +1,17 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest +import logging +from tests.common.pipeline.pipeline_config_test import PipelineConfigInstantiationTestBase +from azure.iot.device.provisioning.pipeline.config import ProvisioningPipelineConfig + + +@pytest.mark.describe("ProvisioningPipelineConfig - Instantiation") +class TestProvisioningPipelineConfigInstantiation(PipelineConfigInstantiationTestBase): + @pytest.fixture + def config_cls(self): + # This fixture is needed for the parent class + return ProvisioningPipelineConfig diff --git a/azure-iot-device/tests/provisioning/pipeline/test_pipeline_events_provisioning.py b/azure-iot-device/tests/provisioning/pipeline/test_pipeline_events_provisioning.py deleted file mode 100644 index b640be74d..000000000 --- a/azure-iot-device/tests/provisioning/pipeline/test_pipeline_events_provisioning.py +++ /dev/null @@ -1,19 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import sys -import logging -from azure.iot.device.provisioning.pipeline import pipeline_events_provisioning -from tests.common.pipeline import pipeline_data_object_test - -logging.basicConfig(level=logging.DEBUG) -this_module = sys.modules[__name__] - -pipeline_data_object_test.add_event_test( - cls=pipeline_events_provisioning.RegistrationResponseEvent, - module=this_module, - positional_arguments=["request_id", "status_code", "key_values", "response_payload"], - keyword_arguments={}, -) diff --git a/azure-iot-device/tests/provisioning/pipeline/test_pipeline_ops_provisioning.py b/azure-iot-device/tests/provisioning/pipeline/test_pipeline_ops_provisioning.py index 62056ee72..b2bf5e0c1 100644 --- a/azure-iot-device/tests/provisioning/pipeline/test_pipeline_ops_provisioning.py +++ b/azure-iot-device/tests/provisioning/pipeline/test_pipeline_ops_provisioning.py @@ -3,35 +3,253 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +import pytest import sys import logging from azure.iot.device.provisioning.pipeline import pipeline_ops_provisioning -from tests.common.pipeline import pipeline_data_object_test +from tests.common.pipeline import pipeline_ops_test logging.basicConfig(level=logging.DEBUG) this_module = sys.modules[__name__] +pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_provisioning.SetSymmetricKeySecurityClientOperation, - module=this_module, - positional_arguments=["security_client"], - keyword_arguments={"callback": None}, + +class SetSymmetricKeySecurityClientOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_provisioning.SetSymmetricKeySecurityClientOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = {"security_client": mocker.MagicMock(), "callback": mocker.MagicMock()} + return kwargs + + +class SetSymmetricKeySecurityClientOperationInstantiationTests( + SetSymmetricKeySecurityClientOperationTestConfig +): + @pytest.mark.it( + "Initializes 'security_client' attribute with the provided 'security_client' parameter" + ) + def test_security_client(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.security_client is init_kwargs["security_client"] + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_provisioning.SetSymmetricKeySecurityClientOperation, + op_test_config_class=SetSymmetricKeySecurityClientOperationTestConfig, + extended_op_instantiation_test_class=SetSymmetricKeySecurityClientOperationInstantiationTests, ) -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation, - module=this_module, - positional_arguments=["provisioning_host", "registration_id", "id_scope"], - keyword_arguments={"client_cert": None, "sas_token": None, "callback": None}, + + +class SetX509SecurityClientOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_provisioning.SetX509SecurityClientOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = {"security_client": mocker.MagicMock(), "callback": mocker.MagicMock()} + return kwargs + + +class SetX509SecurityClientOperationInstantiationTests(SetX509SecurityClientOperationTestConfig): + @pytest.mark.it( + "Initializes 'security_client' attribute with the provided 'security_client' parameter" + ) + def test_security_client(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.security_client is init_kwargs["security_client"] + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_provisioning.SetX509SecurityClientOperation, + op_test_config_class=SetX509SecurityClientOperationTestConfig, + extended_op_instantiation_test_class=SetX509SecurityClientOperationInstantiationTests, ) -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_provisioning.SendRegistrationRequestOperation, - module=this_module, - positional_arguments=["request_id", "request_payload"], - keyword_arguments={"callback": None}, + + +class SetProvisioningClientConnectionArgsOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = { + "provisioning_host": "some_provisioning_host", + "registration_id": "some_registration_id", + "id_scope": "some_id_scope", + "callback": mocker.MagicMock(), + "client_cert": "some_client_cert", + "sas_token": "some_sas_token", + } + return kwargs + + +class SetProvisioningClientConnectionArgsOperationInstantiationTests( + SetProvisioningClientConnectionArgsOperationTestConfig +): + @pytest.mark.it( + "Initializes 'provisioning_host' attribute with the provided 'provisioning_host' parameter" + ) + def test_provisioning_host(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.provisioning_host is init_kwargs["provisioning_host"] + + @pytest.mark.it( + "Initializes 'registration_id' attribute with the provided 'registration_id' parameter" + ) + def test_registration_id(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.registration_id is init_kwargs["registration_id"] + + @pytest.mark.it("Initializes 'id_scope' attribute with the provided 'id_scope' parameter") + def test_id_scope(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.id_scope is init_kwargs["id_scope"] + + @pytest.mark.it("Initializes 'client_cert' attribute with the provided 'client_cert' parameter") + def test_client_cert(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.client_cert is init_kwargs["client_cert"] + + @pytest.mark.it( + "Initializes 'client_cert' attribute to None if no 'client_cert' parameter is provided" + ) + def test_client_cert_default(self, cls_type, init_kwargs): + del init_kwargs["client_cert"] + op = cls_type(**init_kwargs) + assert op.client_cert is None + + @pytest.mark.it("Initializes 'sas_token' attribute with the provided 'sas_token' parameter") + def test_sas_token(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.sas_token is init_kwargs["sas_token"] + + @pytest.mark.it( + "Initializes 'sas_token' attribute to None if no 'sas_token' parameter is provided" + ) + def test_sas_token_default(self, cls_type, init_kwargs): + del init_kwargs["sas_token"] + op = cls_type(**init_kwargs) + assert op.sas_token is None + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation, + op_test_config_class=SetProvisioningClientConnectionArgsOperationTestConfig, + extended_op_instantiation_test_class=SetProvisioningClientConnectionArgsOperationInstantiationTests, ) -pipeline_data_object_test.add_operation_test( - cls=pipeline_ops_provisioning.SendQueryRequestOperation, - module=this_module, - positional_arguments=["request_id", "operation_id", "request_payload"], - keyword_arguments={"callback": None}, + + +class RegisterOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_provisioning.RegisterOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = { + "request_payload": "some_request_payload", + "registration_id": "some_registration_id", + "callback": mocker.MagicMock(), + } + return kwargs + + +class RegisterOperationInstantiationTests(RegisterOperationTestConfig): + @pytest.mark.it( + "Initializes 'request_payload' attribute with the provided 'request_payload' parameter" + ) + def test_request_payload(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.request_payload == init_kwargs["request_payload"] + + @pytest.mark.it( + "Initializes 'registration_id' attribute with the provided 'registration_id' parameter" + ) + def test_registration_id(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.registration_id == init_kwargs["registration_id"] + + @pytest.mark.it("Initializes 'retry_after_timer' attribute to None") + def test_retry_after_timer(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.retry_after_timer is None + + @pytest.mark.it("Initializes 'polling_timer' attribute to None") + def test_polling_timer(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.polling_timer is None + + @pytest.mark.it("Initializes 'provisioning_timeout_timer' attribute to None") + def test_provisioning_timeout_timer(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.provisioning_timeout_timer is None + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_provisioning.RegisterOperation, + op_test_config_class=RegisterOperationTestConfig, + extended_op_instantiation_test_class=RegisterOperationInstantiationTests, +) + + +class PollStatusOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_provisioning.PollStatusOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = { + "operation_id": "some_operation_id", + "request_payload": "some_request_payload", + "callback": mocker.MagicMock(), + } + return kwargs + + +class PollStatusOperationInstantiationTests(PollStatusOperationTestConfig): + @pytest.mark.it( + "Initializes 'operation_id' attribute with the provided 'operation_id' parameter" + ) + def test_operation_id(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.operation_id == init_kwargs["operation_id"] + + @pytest.mark.it( + "Initializes 'request_payload' attribute with the provided 'request_payload' parameter" + ) + def test_request_payload(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.request_payload == init_kwargs["request_payload"] + + @pytest.mark.it("Initializes 'retry_after_timer' attribute to None") + def test_retry_after_timer(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.retry_after_timer is None + + @pytest.mark.it("Initializes 'polling_timer' attribute to None") + def test_polling_timer(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.polling_timer is None + + @pytest.mark.it("Initializes 'provisioning_timeout_timer' attribute to None") + def test_provisioning_timeout_timer(self, cls_type, init_kwargs): + op = cls_type(**init_kwargs) + assert op.provisioning_timeout_timer is None + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_provisioning.PollStatusOperation, + op_test_config_class=PollStatusOperationTestConfig, + extended_op_instantiation_test_class=PollStatusOperationInstantiationTests, ) diff --git a/azure-iot-device/tests/provisioning/pipeline/test_pipeline_stages_provisioning.py b/azure-iot-device/tests/provisioning/pipeline/test_pipeline_stages_provisioning.py index 18f7085bb..305f01a5c 100644 --- a/azure-iot-device/tests/provisioning/pipeline/test_pipeline_stages_provisioning.py +++ b/azure-iot-device/tests/provisioning/pipeline/test_pipeline_stages_provisioning.py @@ -22,12 +22,22 @@ from tests.common.pipeline.helpers import ( all_common_ops, all_common_events, all_except, - make_mock_stage, - UnhandledException, + StageTestBase, ) from azure.iot.device.common.pipeline import pipeline_events_base -from tests.provisioning.pipeline.helpers import all_provisioning_ops, all_provisioning_events +from tests.provisioning.pipeline.helpers import all_provisioning_ops from tests.common.pipeline import pipeline_stage_test +from azure.iot.device.exceptions import ServiceError +import json +import datetime +from azure.iot.device.provisioning.models.registration_result import ( + RegistrationResult, + RegistrationState, +) +from tests.common.pipeline.helpers import StageRunOpTestBase +from azure.iot.device import exceptions +from azure.iot.device.provisioning.pipeline import constant +import threading logging.basicConfig(level=logging.DEBUG) @@ -46,9 +56,22 @@ fake_provisioning_host = "hogwarts.com" fake_id_scope = "weasley_wizard_wheezes" fake_ca_cert = "fake_ca_cert" fake_sas_token = "horcrux_token" +fake_request_id = "Request1234" +fake_operation_id = "Operation4567" +fake_status = "Flying" +fake_assigned_hub = "Dumbledore'sArmy" +fake_sub_status = "FlyingOnHippogriff" +fake_created_dttm = datetime.datetime(2020, 5, 17) +fake_last_update_dttm = datetime.datetime(2020, 10, 17) +fake_etag = "HighQualityFlyingBroom" +fake_payload = "petrificus totalus" +fake_symmetric_key = "Zm9vYmFy" +fake_x509_cert_file = "fantastic_beasts" +fake_x509_cert_key_file = "where_to_find_them" +fake_pass_phrase = "alohomora" -pipeline_stage_test.add_base_pipeline_stage_tests( +pipeline_stage_test.add_base_pipeline_stage_tests_old( cls=pipeline_stages_provisioning.UseSecurityClientStage, module=this_module, all_ops=all_common_ops + all_provisioning_ops, @@ -56,18 +79,32 @@ pipeline_stage_test.add_base_pipeline_stage_tests( pipeline_ops_provisioning.SetSymmetricKeySecurityClientOperation, pipeline_ops_provisioning.SetX509SecurityClientOperation, ], - all_events=all_common_events + all_provisioning_events, + all_events=all_common_events, handled_events=[], ) -fake_symmetric_key = "Zm9vYmFy" -fake_x509_cert_file = "fantastic_beasts" -fake_x509_cert_key_file = "where_to_find_them" -fake_pass_phrase = "alohomora" +pipeline_stage_test.add_base_pipeline_stage_tests_old( + cls=pipeline_stages_provisioning.RegistrationStage, + module=this_module, + all_ops=all_common_ops + all_provisioning_ops, + handled_ops=[pipeline_ops_provisioning.RegisterOperation], + all_events=all_common_events, + handled_events=[], +) -def create_x509_security_client(): +pipeline_stage_test.add_base_pipeline_stage_tests_old( + cls=pipeline_stages_provisioning.PollingStatusStage, + module=this_module, + all_ops=all_common_ops + all_provisioning_ops, + handled_ops=[pipeline_ops_provisioning.PollStatusOperation], + all_events=all_common_events, + handled_events=[], +) + + +def make_mock_x509_security_client(): mock_x509 = X509(fake_x509_cert_file, fake_x509_cert_key_file, fake_pass_phrase) return X509SecurityClient( provisioning_host=fake_provisioning_host, @@ -77,7 +114,7 @@ def create_x509_security_client(): ) -def create_symmetric_security_client(): +def make_mock_symmetric_security_client(): return SymmetricKeySecurityClient( provisioning_host=fake_provisioning_host, registration_id=fake_registration_id, @@ -86,155 +123,1061 @@ def create_symmetric_security_client(): ) -different_security_ops = [ - { - "name": "set symmetric key security", - "current_op_class": pipeline_ops_provisioning.SetSymmetricKeySecurityClientOperation, - "security_client_function_name": create_symmetric_security_client, - }, - { - "name": "set x509 security", - "current_op_class": pipeline_ops_provisioning.SetX509SecurityClientOperation, - "security_client_function_name": create_x509_security_client, - }, -] +class FakeRegistrationResult(object): + def __init__(self, operation_id, status, state): + self.operationId = operation_id + self.status = status + self.registrationState = state + + def __str__(self): + return "\n".join([str(self.registrationState), self.status]) -@pytest.fixture -def security_stage(mocker): - return make_mock_stage(mocker, pipeline_stages_provisioning.UseSecurityClientStage) +class FakeRegistrationState(object): + def __init__(self, payload): + self.deviceId = fake_device_id + self.assignedHub = fake_assigned_hub + self.payload = payload + self.substatus = fake_sub_status + + def __str__(self): + return "\n".join( + [self.deviceId, self.assignedHub, self.substatus, self.get_payload_string()] + ) + + def get_payload_string(self): + return json.dumps(self.payload, default=lambda o: o.__dict__, sort_keys=True) -@pytest.fixture -def set_security_client(callback, params_security_ops): - # Create new security client every time to pass into fixture to avoid re-use of old security client - # Otherwise the exception/failure raised by one test is makes the next test fail. - op = params_security_ops["current_op_class"]( - security_client=params_security_ops["security_client_function_name"]() - ) - op.callback = callback - return op +def create_registration_result(fake_payload, status): + state = FakeRegistrationState(payload=fake_payload) + return FakeRegistrationResult(fake_operation_id, status, state) -@pytest.mark.parametrize( - "params_security_ops", - different_security_ops, - ids=[x["current_op_class"].__name__ for x in different_security_ops], +def get_registration_result_as_bytes(registration_result): + return json.dumps(registration_result, default=lambda o: o.__dict__).encode("utf-8") + + +################### +# COMMON FIXTURES # +################### + + +@pytest.fixture(params=[True, False], ids=["With error", "No error"]) +def op_error(request, arbitrary_exception): + if request.param: + return arbitrary_exception + else: + return None + + +############################# +# USE SECURITY CLIENT STAGE # +############################# + + +class UseSecurityClientStageTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_stages_provisioning.UseSecurityClientStage + + @pytest.fixture + def init_kwargs(self): + return {} + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + return stage + + +pipeline_stage_test.add_base_pipeline_stage_tests( + test_module=this_module, + stage_class_under_test=pipeline_stages_provisioning.UseSecurityClientStage, + stage_test_config_class=UseSecurityClientStageTestConfig, ) + + @pytest.mark.describe( - "UseSecurityClientStage run_op function with SetProvisioningClientConnectionArgsOperation operations" + "UseSecurityClientStage - .run_op() -- Called with SetSymmetricKeySecurityClientOperation" ) -class TestUseSymmetricKeyOrX509SecurityClientRunOpWithSetSecurityClient(object): - @pytest.mark.it("runs SetProvisioningClientConnectionArgsOperation op on the next stage") - def test_runs_set_security_client_args(self, mocker, security_stage, set_security_client): - security_stage.next._execute_op = mocker.Mock() - security_stage.run_op(set_security_client) - assert security_stage.next._execute_op.call_count == 1 - set_args = security_stage.next._execute_op.call_args[0][0] - assert isinstance( - set_args, pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation +class TestUseSecurityClientStageRunOpWithSetSymmetricKeySecurityClientOperation( + StageRunOpTestBase, UseSecurityClientStageTestConfig +): + @pytest.fixture + def op(self, mocker): + security_client = SymmetricKeySecurityClient( + provisioning_host="hogwarts.com", + registration_id="registered_remembrall", + id_scope="weasley_wizard_wheezes", + symmetric_key="Zm9vYmFy", + ) + security_client.get_current_sas_token = mocker.MagicMock() + return pipeline_ops_provisioning.SetSymmetricKeySecurityClientOperation( + security_client=security_client, callback=mocker.MagicMock() ) @pytest.mark.it( - "Calls the SetSecurityClient callback with the SetProvisioningClientConnectionArgsOperation error" - "when the SetProvisioningClientConnectionArgsOperation op raises an Exception" + "Sends a new SetProvisioningClientConnectionArgsOperation op down the pipeline, containing connection info from the op's security client" ) - def test_set_security_client_raises_exception( - self, mocker, security_stage, fake_exception, set_security_client - ): - security_stage.next._execute_op = mocker.Mock(side_effect=fake_exception) - security_stage.run_op(set_security_client) - assert_callback_failed(op=set_security_client, error=fake_exception) + def test_send_new_op_down(self, mocker, op, stage): + stage.run_op(op) + + # A SetProvisioningClientConnectionArgsOperation has been sent down the pipeline + stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance( + new_op, pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation + ) + + # The SetProvisioningClientConnectionArgsOperation has details from the security client + assert new_op.provisioning_host == op.security_client.provisioning_host + assert new_op.registration_id == op.security_client.registration_id + assert new_op.id_scope == op.security_client.id_scope + assert new_op.sas_token == op.security_client.get_current_sas_token.return_value + assert new_op.client_cert is None @pytest.mark.it( - "Allows any BaseExceptions raised by SetProvisioningClientConnectionArgsOperation operations to propagate" + "Completes the original SetSymmetricKeySecurityClientOperation with the same status as the new SetProvisioningClientConnectionArgsOperation, if the new SetProvisioningClientConnectionArgsOperation is completed" ) - def test_set_security_client_raises_base_exception( - self, mocker, security_stage, fake_base_exception, set_security_client - ): - security_stage.next._execute_op = mocker.Mock(side_effect=fake_base_exception) - with pytest.raises(UnhandledException): - security_stage.run_op(set_security_client) + def test_new_op_completes_success(self, mocker, op, stage, op_error): + stage.run_op(op) + stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance( + new_op, pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation + ) + + assert not op.completed + assert not new_op.completed + + new_op.complete(error=op_error) + + assert new_op.completed + assert new_op.error is op_error + assert op.completed + assert op.error is op_error + + +@pytest.mark.describe( + "UseSecurityClientStage - .run_op() -- Called with SetX509SecurityClientOperation" +) +class TestUseSecurityClientStageRunOpWithSetX509SecurityClientOperation( + StageRunOpTestBase, UseSecurityClientStageTestConfig +): + @pytest.fixture + def op(self, mocker): + x509 = X509(cert_file="fake_cert.txt", key_file="fake_key.txt", pass_phrase="alohomora") + security_client = X509SecurityClient( + provisioning_host="hogwarts.com", + registration_id="registered_remembrall", + id_scope="weasley_wizard_wheezes", + x509=x509, + ) + security_client.get_x509_certificate = mocker.MagicMock() + return pipeline_ops_provisioning.SetX509SecurityClientOperation( + security_client=security_client, callback=mocker.MagicMock() + ) @pytest.mark.it( - "Retrieves sas_token or x509_client_cert on the security_client and passes the result as the attribute of the next operation" + "Sends a new SetProvisioningClientConnectionArgsOperation op down the pipeline, containing connection info from the op's security client" ) - def test_calls_get_current_sas_token_or_get_x509_certificate( - self, mocker, security_stage, set_security_client, params_security_ops - ): - if ( - params_security_ops["current_op_class"].__name__ - == "SetSymmetricKeySecurityClientOperation" - ): - spy_method = mocker.spy(set_security_client.security_client, "get_current_sas_token") - elif params_security_ops["current_op_class"].__name__ == "SetX509SecurityClientOperation": - spy_method = mocker.spy(set_security_client.security_client, "get_x509_certificate") + def test_send_new_op_down(self, mocker, op, stage): + stage.run_op(op) - security_stage.run_op(set_security_client) - assert spy_method.call_count == 1 + # A SetProvisioningClientConnectionArgsOperation has been sent down the pipeline + stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance( + new_op, pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation + ) - set_connection_args_op = security_stage.next._execute_op.call_args[0][0] - - if ( - params_security_ops["current_op_class"].__name__ - == "SetSymmetricKeySecurityClientOperation" - ): - assert "SharedAccessSignature" in set_connection_args_op.sas_token - assert "skn=registration" in set_connection_args_op.sas_token - assert fake_id_scope in set_connection_args_op.sas_token - assert fake_registration_id in set_connection_args_op.sas_token - - elif params_security_ops["current_op_class"].__name__ == "SetX509SecurityClientOperation": - assert set_connection_args_op.client_cert.certificate_file == fake_x509_cert_file - assert set_connection_args_op.client_cert.key_file == fake_x509_cert_key_file - assert set_connection_args_op.client_cert.pass_phrase == fake_pass_phrase + # The SetProvisioningClientConnectionArgsOperation has details from the security client + assert new_op.provisioning_host == op.security_client.provisioning_host + assert new_op.registration_id == op.security_client.registration_id + assert new_op.id_scope == op.security_client.id_scope + assert new_op.client_cert == op.security_client.get_x509_certificate.return_value + assert new_op.sas_token is None @pytest.mark.it( - "Calls the callback of setting security client with no error when the next operation of " - "etting token or setting client_cert operation succeeds" + "Completes the original SetX509SecurityClientOperation with the same status as the new SetProvisioningClientConnectionArgsOperation, if the new SetProvisioningClientConnectionArgsOperation is completed" ) - def test_returns_success_if_set_sas_token_or_set_client_client_cert_succeeds( - self, security_stage, set_security_client - ): - security_stage.run_op(set_security_client) - assert_callback_succeeded(op=set_security_client) + def test_new_op_completes_success(self, mocker, op, stage, op_error): + stage.run_op(op) + stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance( + new_op, pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation + ) + + assert not op.completed + assert not new_op.completed + + new_op.complete(error=op_error) + + assert new_op.completed + assert new_op.error is op_error + assert op.completed + assert op.error is op_error + + +############################### +# REGISTRATION STAGE # +############################### + + +class RegistrationStageConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_stages_provisioning.RegistrationStage + + @pytest.fixture + def init_kwargs(self): + return {} + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + return stage + + +pipeline_stage_test.add_base_pipeline_stage_tests( + test_module=this_module, + stage_class_under_test=pipeline_stages_provisioning.RegistrationStage, + stage_test_config_class=RegistrationStageConfig, +) + + +@pytest.mark.parametrize( + "request_payload", + [pytest.param(" ", id="empty payload"), pytest.param(fake_payload, id="some payload")], +) +@pytest.mark.describe("RegistrationStage - .run_op() -- called with RegisterOperation") +class TestRegistrationStageWithRegisterOperation(StageRunOpTestBase, RegistrationStageConfig): + @pytest.fixture + def op(self, stage, mocker, request_payload): + op = pipeline_ops_provisioning.RegisterOperation( + request_payload, fake_registration_id, callback=mocker.MagicMock() + ) + return op + + @pytest.fixture + def request_body(self, request_payload): + return '{{"payload": {json_payload}, "registrationId": "{reg_id}"}}'.format( + reg_id=fake_registration_id, json_payload=json.dumps(request_payload) + ) @pytest.mark.it( - "Returns error when get_current_sas_token or get_x509_certificate raises an exception" + "Sends a new RequestAndResponseOperation down the pipeline, configured to request a registration from provisioning service" ) - def test_get_current_sas_token_or_get_x509_certificate_raises_exception( - self, mocker, fake_exception, security_stage, set_security_client, params_security_ops - ): - if ( - params_security_ops["current_op_class"].__name__ - == "SetSymmetricKeySecurityClientOperation" - ): - set_security_client.security_client.get_current_sas_token = mocker.Mock( - side_effect=fake_exception - ) - elif params_security_ops["current_op_class"].__name__ == "SetX509SecurityClientOperation": - set_security_client.security_client.get_x509_certificate = mocker.Mock( - side_effect=fake_exception - ) - security_stage.run_op(set_security_client) - assert_callback_failed(op=set_security_client, error=fake_exception) + def test_request_and_response_op(self, stage, op, request_body): + stage.run_op(op) + + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_base.RequestAndResponseOperation) + assert new_op.request_type == "register" + assert new_op.method == "PUT" + assert new_op.resource_location == "/" + assert new_op.request_body == request_body + + +@pytest.mark.describe("RegistrationStage - .run_op() -- Called with other arbitrary operation") +class TestRegistrationStageWithArbitraryOperation(StageRunOpTestBase, RegistrationStageConfig): + @pytest.fixture + def op(self, arbitrary_op): + return arbitrary_op + + @pytest.mark.it("Sends the operation down the pipeline") + def test_sends_op_down(self, mocker, stage, op): + stage.run_op(op) + + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) + + +@pytest.mark.describe( + "RegistrationStage - EVENT: RequestAndResponseOperation created from RegisterOperation is completed" +) +@pytest.mark.parametrize( + "request_payload", + [pytest.param(" ", id="empty payload"), pytest.param(fake_payload, id="some payload")], +) +class TestRegistrationStageWithRegisterOperationCompleted(RegistrationStageConfig): + @pytest.fixture + def send_registration_op(self, mocker, request_payload): + op = pipeline_ops_provisioning.RegisterOperation( + request_payload, fake_registration_id, callback=mocker.MagicMock() + ) + return op + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs, send_registration_op): + stage = cls_type(**init_kwargs) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + # Run the registration operation + stage.run_op(send_registration_op) + return stage + + @pytest.fixture + def request_and_response_op(self, stage): + assert stage.send_op_down.call_count == 1 + op = stage.send_op_down.call_args[0][0] + assert isinstance(op, pipeline_ops_base.RequestAndResponseOperation) + # reset the stage mock for convenience + stage.send_op_down.reset_mock() + return op + + @pytest.fixture + def request_body(self, request_payload): + return '{{"payload": {json_payload}, "registrationId": "{reg_id}"}}'.format( + reg_id=fake_registration_id, json_payload=json.dumps(request_payload) + ) @pytest.mark.it( - "Allows any BaseExceptions raised by get_current_sas_token or get_x509_certificate to propagate" + "Completes the RegisterOperation unsuccessfully, with the error from the RequestAndResponseOperation, if the RequestAndResponseOperation is completed unsuccessfully" ) - def test_get_current_sas_token_get_x509_certificate_raises_base_exception( - self, mocker, fake_base_exception, security_stage, set_security_client, params_security_ops + @pytest.mark.parametrize( + "status_code", + [ + pytest.param(None, id="Status Code: None"), + pytest.param(200, id="Status Code: 200"), + pytest.param(300, id="Status Code: 300"), + pytest.param(400, id="Status Code: 400"), + pytest.param(500, id="Status Code: 500"), + ], + ) + @pytest.mark.parametrize( + "has_response_body", [True, False], ids=["With Response Body", "No Response Body"] + ) + def test_request_and_response_op_completed_with_err( + self, + stage, + send_registration_op, + request_and_response_op, + status_code, + has_response_body, + arbitrary_exception, ): - if ( - params_security_ops["current_op_class"].__name__ - == "SetSymmetricKeySecurityClientOperation" - ): - set_security_client.security_client.get_current_sas_token = mocker.Mock( - side_effect=fake_base_exception - ) - elif params_security_ops["current_op_class"].__name__ == "SetX509SecurityClientOperation": - set_security_client.security_client.get_x509_certificate = mocker.Mock( - side_effect=fake_base_exception - ) - with pytest.raises(UnhandledException): - security_stage.run_op(set_security_client) + assert not send_registration_op.completed + assert not request_and_response_op.completed + + # NOTE: It shouldn't happen that an operation completed with error has a status code or a + # response body, but it IS possible. + request_and_response_op.status_code = status_code + if has_response_body: + request_and_response_op.response_body = b'{"key": "value"}' + request_and_response_op.complete(error=arbitrary_exception) + + assert request_and_response_op.completed + assert request_and_response_op.error is arbitrary_exception + assert send_registration_op.completed + assert send_registration_op.error is arbitrary_exception + assert send_registration_op.registration_result is None + + @pytest.mark.it( + "Completes the RegisterOperation unsuccessfully with a ServiceError if the RequestAndResponseOperation is completed with a status code >= 300 and less than 429" + ) + @pytest.mark.parametrize( + "has_response_body", [True, False], ids=["With Response Body", "No Response Body"] + ) + @pytest.mark.parametrize( + "status_code", + [ + pytest.param(300, id="Status Code: 300"), + pytest.param(400, id="Status Code: 400"), + pytest.param(428, id="Status Code: 428"), + ], + ) + def test_request_and_response_op_completed_success_with_bad_code( + self, stage, send_registration_op, request_and_response_op, status_code, has_response_body + ): + assert not send_registration_op.completed + assert not request_and_response_op.completed + + request_and_response_op.status_code = status_code + if has_response_body: + request_and_response_op.response_body = b'{"key": "value"}' + request_and_response_op.complete() + + assert request_and_response_op.completed + assert request_and_response_op.error is None + assert send_registration_op.completed + assert isinstance(send_registration_op.error, ServiceError) + # Twin is NOT returned + assert send_registration_op.registration_result is None + + @pytest.mark.it( + "Decodes, deserializes, and returns registration_result on the RegisterOperation op when RequestAndResponseOperation completes with no error if the status code < 300 and if status is 'assigned'" + ) + def test_request_and_response_op_completed_success_with_status_assigned( + self, stage, request_payload, send_registration_op, request_and_response_op + ): + registration_result = create_registration_result(request_payload, "assigned") + + assert not send_registration_op.completed + assert not request_and_response_op.completed + + request_and_response_op.status_code = 200 + request_and_response_op.retry_after = None + request_and_response_op.response_body = get_registration_result_as_bytes( + registration_result + ) + request_and_response_op.complete() + + assert request_and_response_op.completed + assert request_and_response_op.error is None + assert send_registration_op.completed + assert send_registration_op.error is None + # We need to assert string representations as these are inherently different objects + assert str(send_registration_op.registration_result) == str(registration_result) + + @pytest.mark.it( + "Decodes, deserializes, and returns registration_result along with an error on the RegisterOperation op when RequestAndResponseOperation completes with status code < 300 and status 'failed'" + ) + def test_request_and_response_op_completed_success_with_status_failed( + self, stage, request_payload, send_registration_op, request_and_response_op + ): + registration_result = create_registration_result(request_payload, "failed") + + assert not send_registration_op.completed + assert not request_and_response_op.completed + + request_and_response_op.status_code = 200 + request_and_response_op.retry_after = None + request_and_response_op.response_body = get_registration_result_as_bytes( + registration_result + ) + request_and_response_op.complete() + + assert request_and_response_op.completed + assert request_and_response_op.error is None + assert send_registration_op.completed + assert isinstance(send_registration_op.error, ServiceError) + # We need to assert string representations as these are inherently different objects + assert str(send_registration_op.registration_result) == str(registration_result) + assert "failed registration status" in str(send_registration_op.error) + + @pytest.mark.it( + "Returns error on the RegisterOperation op when RequestAndResponseOperation completes with status code < 300 and some unknown status" + ) + def test_request_and_response_op_completed_success_with_unknown_status( + self, stage, request_payload, send_registration_op, request_and_response_op + ): + registration_result = create_registration_result(request_payload, "quidditching") + + assert not send_registration_op.completed + assert not request_and_response_op.completed + + request_and_response_op.status_code = 200 + request_and_response_op.retry_after = None + request_and_response_op.response_body = get_registration_result_as_bytes( + registration_result + ) + request_and_response_op.complete() + + assert request_and_response_op.completed + assert request_and_response_op.error is None + assert send_registration_op.completed + assert isinstance(send_registration_op.error, ServiceError) + assert "invalid registration status" in str(send_registration_op.error) + + @pytest.mark.it( + "Decodes, deserializes the response from RequestAndResponseOperation and creates another op if the status code < 300 and if status is 'assigning'" + ) + def test_spawns_another_op_request_and_response_op_completed_success_with_status_assigning( + self, mocker, stage, request_payload, send_registration_op, request_and_response_op + ): + mock_timer = mocker.patch( + "azure.iot.device.provisioning.pipeline.pipeline_stages_provisioning.Timer" + ) + + mocker.spy(send_registration_op, "spawn_worker_op") + registration_result = create_registration_result(request_payload, "assigning") + + assert not send_registration_op.completed + assert not request_and_response_op.completed + + request_and_response_op.status_code = 200 + request_and_response_op.retry_after = None + request_and_response_op.response_body = get_registration_result_as_bytes( + registration_result + ) + request_and_response_op.complete() + + assert send_registration_op.retry_after_timer is None + assert send_registration_op.polling_timer is not None + timer_callback = mock_timer.call_args[0][1] + timer_callback() + + assert request_and_response_op.completed + assert request_and_response_op.error is None + assert not send_registration_op.completed + assert send_registration_op.error is None + assert ( + send_registration_op.spawn_worker_op.call_args[1]["operation_id"] == fake_operation_id + ) + + +class RetryStageConfig(object): + @pytest.fixture + def init_kwargs(self): + return {} + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + mocker.spy(stage, "run_op") + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + return stage + + +@pytest.mark.describe("RegistrationStage - .run_op() -- retried again with RegisterOperation") +@pytest.mark.parametrize( + "request_payload", + [pytest.param(" ", id="empty payload"), pytest.param(fake_payload, id="some payload")], +) +class TestRegistrationStageWithRetryOfRegisterOperation(RetryStageConfig): + @pytest.fixture + def cls_type(self): + return pipeline_stages_provisioning.RegistrationStage + + @pytest.fixture + def op(self, stage, mocker, request_payload): + return pipeline_ops_provisioning.RegisterOperation( + request_payload, fake_registration_id, callback=mocker.MagicMock() + ) + + @pytest.fixture + def request_body(self, request_payload): + return '{{"payload": {json_payload}, "registrationId": "{reg_id}"}}'.format( + reg_id=fake_registration_id, json_payload=json.dumps(request_payload) + ) + + @pytest.mark.it( + "Decodes, deserializes the response from RequestAndResponseOperation and retries the op if the status code > 429" + ) + def test_stage_retries_op_if_next_stage_responds_with_status_code_greater_than_429( + self, mocker, stage, op, request_body, request_payload + ): + mock_timer = mocker.patch( + "azure.iot.device.provisioning.pipeline.pipeline_stages_provisioning.Timer" + ) + + stage.run_op(op) + assert stage.send_op_down.call_count == 1 + next_op = stage.send_op_down.call_args[0][0] + assert isinstance(next_op, pipeline_ops_base.RequestAndResponseOperation) + + next_op.status_code = 430 + next_op.retry_after = "1" + registration_result = create_registration_result(request_payload, "flying") + next_op.response_body = get_registration_result_as_bytes(registration_result) + next_op.complete() + + assert op.retry_after_timer is not None + assert op.polling_timer is None + timer_callback = mock_timer.call_args[0][1] + timer_callback() + + assert stage.run_op.call_count == 2 + assert stage.send_op_down.call_count == 2 + + next_op_2 = stage.send_op_down.call_args[0][0] + assert isinstance(next_op_2, pipeline_ops_base.RequestAndResponseOperation) + assert next_op_2.request_type == "register" + assert next_op_2.method == "PUT" + assert next_op_2.resource_location == "/" + assert next_op_2.request_body == request_body + + +@pytest.mark.describe( + "RegistrationStage - .run_op() -- Called with register request operation eligible for timeout" +) +class TestRegistrationStageWithTimeoutOfRegisterOperation( + StageRunOpTestBase, RegistrationStageConfig +): + @pytest.fixture + def op(self, stage, mocker): + op = pipeline_ops_provisioning.RegisterOperation( + " ", fake_registration_id, callback=mocker.MagicMock() + ) + return op + + @pytest.fixture + def mock_timer(self, mocker): + return mocker.patch( + "azure.iot.device.provisioning.pipeline.pipeline_stages_provisioning.Timer" + ) + + @pytest.mark.it( + "Adds a provisioning timeout timer with the interval specified in the configuration to the operation, and starts it" + ) + def test_adds_timer(self, mocker, stage, op, mock_timer): + stage.run_op(op) + + assert mock_timer.call_count == 1 + assert mock_timer.call_args == mocker.call(constant.DEFAULT_TIMEOUT_INTERVAL, mocker.ANY) + assert op.provisioning_timeout_timer is mock_timer.return_value + assert op.provisioning_timeout_timer.start.call_count == 1 + assert op.provisioning_timeout_timer.start.call_args == mocker.call() + + @pytest.mark.it( + "Sends converted RequestResponse Op down the pipeline after attaching timer to the original op" + ) + def test_sends_down(self, mocker, stage, op, mock_timer): + stage.run_op(op) + + assert stage.send_op_down.call_count == 1 + + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_base.RequestAndResponseOperation) + + assert op.provisioning_timeout_timer is mock_timer.return_value + + @pytest.mark.it("Completes the operation unsuccessfully, with a ServiceError due to timeout") + def test_not_complete_timeout(self, mocker, stage, op, mock_timer): + # Apply the timer + stage.run_op(op) + assert not op.completed + assert mock_timer.call_count == 1 + on_timer_complete = mock_timer.call_args[0][1] + + # Call timer complete callback (indicating timer completion) + on_timer_complete() + + # Op is now completed with error + assert op.completed + assert isinstance(op.error, exceptions.ServiceError) + assert "register" in op.error.args[0] + + @pytest.mark.it( + "Completes the operation successfully, cancels and clears the operation's timeout timer" + ) + def test_complete_before_timeout(self, mocker, stage, op, mock_timer): + # Apply the timer + stage.run_op(op) + assert not op.completed + assert mock_timer.call_count == 1 + mock_timer_inst = op.provisioning_timeout_timer + assert mock_timer_inst is mock_timer.return_value + assert mock_timer_inst.cancel.call_count == 0 + + # Complete the next operation + new_op = stage.send_op_down.call_args[0][0] + new_op.complete() + + # Timer is now cancelled and cleared + assert mock_timer_inst.cancel.call_count == 1 + assert mock_timer_inst.cancel.call_args == mocker.call() + assert op.provisioning_timeout_timer is None + + +class PollingStageConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_stages_provisioning.PollingStatusStage + + @pytest.fixture + def init_kwargs(self): + return {} + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs): + stage = cls_type(**init_kwargs) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + return stage + + +pipeline_stage_test.add_base_pipeline_stage_tests( + test_module=this_module, + stage_class_under_test=pipeline_stages_provisioning.PollingStatusStage, + stage_test_config_class=PollingStageConfig, +) + + +@pytest.mark.describe("PollingStatusStage - .run_op() -- called with PollStatusOperation") +class TestPollingStatusStageWithPollStatusOperation(StageRunOpTestBase, PollingStageConfig): + @pytest.fixture + def op(self, stage, mocker): + op = pipeline_ops_provisioning.PollStatusOperation( + fake_operation_id, " ", callback=mocker.MagicMock() + ) + return op + + @pytest.mark.it( + "Sends a new RequestAndResponseOperation down the pipeline, configured to request a registration from provisioning service" + ) + def test_request_and_response_op(self, stage, op): + stage.run_op(op) + + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_base.RequestAndResponseOperation) + assert new_op.request_type == "query" + assert new_op.method == "GET" + assert new_op.resource_location == "/" + assert new_op.request_body == " " + + +@pytest.mark.describe("PollingStatusStage - .run_op() -- Called with other arbitrary operation") +class TestPollingStatusStageWithArbitraryOperation(StageRunOpTestBase, PollingStageConfig): + @pytest.fixture + def op(self, arbitrary_op): + return arbitrary_op + + @pytest.mark.it("Sends the operation down the pipeline") + def test_sends_op_down(self, mocker, stage, op): + stage.run_op(op) + + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) + + +@pytest.mark.describe( + "PollingStatusStage - EVENT: RequestAndResponseOperation created from PollStatusOperation is completed" +) +class TestPollingStatusStageWithPollStatusOperationCompleted(PollingStageConfig): + @pytest.fixture + def send_query_op(self, mocker): + op = pipeline_ops_provisioning.PollStatusOperation( + fake_operation_id, " ", callback=mocker.MagicMock() + ) + return op + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs, send_query_op): + stage = cls_type(**init_kwargs) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + # Run the registration operation + stage.run_op(send_query_op) + return stage + + @pytest.fixture + def request_and_response_op(self, stage): + assert stage.send_op_down.call_count == 1 + op = stage.send_op_down.call_args[0][0] + assert isinstance(op, pipeline_ops_base.RequestAndResponseOperation) + # reset the stage mock for convenience + stage.send_op_down.reset_mock() + return op + + @pytest.mark.it( + "Completes the PollStatusOperation unsuccessfully, with the error from the RequestAndResponseOperation, if the RequestAndResponseOperation is completed unsuccessfully" + ) + @pytest.mark.parametrize( + "status_code", + [ + pytest.param(None, id="Status Code: None"), + pytest.param(200, id="Status Code: 200"), + pytest.param(300, id="Status Code: 300"), + pytest.param(400, id="Status Code: 400"), + pytest.param(500, id="Status Code: 500"), + ], + ) + @pytest.mark.parametrize( + "has_response_body", [True, False], ids=["With Response Body", "No Response Body"] + ) + def test_request_and_response_op_completed_with_err( + self, + stage, + send_query_op, + request_and_response_op, + status_code, + has_response_body, + arbitrary_exception, + ): + assert not send_query_op.completed + assert not request_and_response_op.completed + + # NOTE: It shouldn't happen that an operation completed with error has a status code or a + # response body, but it IS possible. + request_and_response_op.status_code = status_code + if has_response_body: + request_and_response_op.response_body = b'{"key": "value"}' + request_and_response_op.complete(error=arbitrary_exception) + + assert request_and_response_op.completed + assert request_and_response_op.error is arbitrary_exception + assert send_query_op.completed + assert send_query_op.error is arbitrary_exception + assert send_query_op.registration_result is None + + @pytest.mark.it( + "Completes the PollStatusOperation unsuccessfully with a ServiceError if the RequestAndResponseOperation is completed with a status code >= 300 and less than 429" + ) + @pytest.mark.parametrize( + "has_response_body", [True, False], ids=["With Response Body", "No Response Body"] + ) + @pytest.mark.parametrize( + "status_code", + [ + pytest.param(300, id="Status Code: 300"), + pytest.param(400, id="Status Code: 400"), + pytest.param(428, id="Status Code: 428"), + ], + ) + def test_request_and_response_op_completed_success_with_bad_code( + self, stage, send_query_op, request_and_response_op, status_code, has_response_body + ): + assert not send_query_op.completed + assert not request_and_response_op.completed + + request_and_response_op.status_code = status_code + if has_response_body: + request_and_response_op.response_body = b'{"key": "value"}' + request_and_response_op.complete() + + assert request_and_response_op.completed + assert request_and_response_op.error is None + assert send_query_op.completed + assert isinstance(send_query_op.error, ServiceError) + # Twin is NOT returned + assert send_query_op.registration_result is None + + @pytest.mark.it( + "Decodes, deserializes, and returns registration_result on the PollStatusOperation op when RequestAndResponseOperation completes with no error if the status code < 300 and if status is 'assigned'" + ) + def test_request_and_response_op_completed_success_with_status_assigned( + self, stage, send_query_op, request_and_response_op + ): + registration_result = create_registration_result(" ", "assigned") + + assert not send_query_op.completed + assert not request_and_response_op.completed + + request_and_response_op.status_code = 200 + request_and_response_op.retry_after = None + request_and_response_op.response_body = get_registration_result_as_bytes( + registration_result + ) + request_and_response_op.complete() + + assert request_and_response_op.completed + assert request_and_response_op.error is None + assert send_query_op.completed + assert send_query_op.error is None + # We need to assert string representations as these are inherently different objects + assert str(send_query_op.registration_result) == str(registration_result) + + @pytest.mark.it( + "Decodes, deserializes, and returns registration_result along with an error on the PollStatusOperation op when RequestAndResponseOperation completes with status code < 300 and status 'failed'" + ) + def test_request_and_response_op_completed_success_with_status_failed( + self, stage, send_query_op, request_and_response_op + ): + registration_result = create_registration_result(" ", "failed") + + assert not send_query_op.completed + assert not request_and_response_op.completed + + request_and_response_op.status_code = 200 + request_and_response_op.retry_after = None + request_and_response_op.response_body = get_registration_result_as_bytes( + registration_result + ) + request_and_response_op.complete() + + assert request_and_response_op.completed + assert request_and_response_op.error is None + assert send_query_op.completed + assert isinstance(send_query_op.error, ServiceError) + # We need to assert string representations as these are inherently different objects + assert str(send_query_op.registration_result) == str(registration_result) + assert "failed registration status" in str(send_query_op.error) + + @pytest.mark.it( + "Returns error on the PollStatusOperation op when RequestAndResponseOperation completes with status code < 300 and some unknown status" + ) + def test_request_and_response_op_completed_success_with_unknown_status( + self, stage, send_query_op, request_and_response_op + ): + registration_result = create_registration_result(" ", "quidditching") + + assert not send_query_op.completed + assert not request_and_response_op.completed + + request_and_response_op.status_code = 200 + request_and_response_op.retry_after = None + request_and_response_op.response_body = get_registration_result_as_bytes( + registration_result + ) + request_and_response_op.complete() + + assert request_and_response_op.completed + assert request_and_response_op.error is None + assert send_query_op.completed + assert isinstance(send_query_op.error, ServiceError) + assert "invalid registration status" in str(send_query_op.error) + + +@pytest.mark.describe("PollingStatusStage - .run_op() -- retried again with PollStatusOperation") +class TestPollingStatusStageWithPollStatusRetryOperation(RetryStageConfig): + @pytest.fixture + def cls_type(self): + return pipeline_stages_provisioning.PollingStatusStage + + @pytest.fixture + def op(self, stage, mocker): + op = pipeline_ops_provisioning.PollStatusOperation( + fake_operation_id, " ", callback=mocker.MagicMock() + ) + return op + + @pytest.mark.it( + "Decodes, deserializes the response from RequestAndResponseOperation and retries the op if the status code > 429" + ) + def test_stage_retries_op_if_next_stage_responds_with_status_code_greater_than_429( + self, mocker, stage, op + ): + mock_timer = mocker.patch( + "azure.iot.device.provisioning.pipeline.pipeline_stages_provisioning.Timer" + ) + + stage.run_op(op) + assert stage.send_op_down.call_count == 1 + next_op = stage.send_op_down.call_args[0][0] + assert isinstance(next_op, pipeline_ops_base.RequestAndResponseOperation) + + next_op.status_code = 430 + next_op.retry_after = "1" + registration_result = create_registration_result(" ", "flying") + next_op.response_body = get_registration_result_as_bytes(registration_result) + next_op.complete() + + assert op.retry_after_timer is not None + assert op.polling_timer is None + timer_callback = mock_timer.call_args[0][1] + timer_callback() + + assert stage.run_op.call_count == 2 + assert stage.send_op_down.call_count == 2 + + next_op_2 = stage.send_op_down.call_args[0][0] + assert isinstance(next_op_2, pipeline_ops_base.RequestAndResponseOperation) + assert next_op_2.request_type == "query" + assert next_op_2.method == "GET" + assert next_op_2.resource_location == "/" + assert next_op_2.request_body == " " + + @pytest.mark.it( + "Decodes, deserializes the response from RequestAndResponseOperation and retries the op if the status code < 300 and if status is 'assigning'" + ) + def test_stage_retries_op_if_next_stage_responds_with_status_assigning(self, mocker, stage, op): + mock_timer = mocker.patch( + "azure.iot.device.provisioning.pipeline.pipeline_stages_provisioning.Timer" + ) + + stage.run_op(op) + assert stage.send_op_down.call_count == 1 + next_op = stage.send_op_down.call_args[0][0] + assert isinstance(next_op, pipeline_ops_base.RequestAndResponseOperation) + + next_op.status_code = 228 + next_op.retry_after = "1" + registration_result = create_registration_result(" ", "assigning") + next_op.response_body = get_registration_result_as_bytes(registration_result) + next_op.complete() + + assert op.retry_after_timer is None + assert op.polling_timer is not None + timer_callback = mock_timer.call_args[0][1] + timer_callback() + + assert stage.run_op.call_count == 2 + assert stage.send_op_down.call_count == 2 + + next_op_2 = stage.send_op_down.call_args[0][0] + assert isinstance(next_op_2, pipeline_ops_base.RequestAndResponseOperation) + assert next_op_2.request_type == "query" + assert next_op_2.method == "GET" + assert next_op_2.resource_location == "/" + assert next_op_2.request_body == " " + + +@pytest.mark.describe( + "RegistrationStage - .run_op() -- Called with register request operation eligible for timeout" +) +class TestPollingStageWithTimeoutOfQueryOperation(StageRunOpTestBase, PollingStageConfig): + @pytest.fixture + def op(self, stage, mocker): + op = pipeline_ops_provisioning.PollStatusOperation( + fake_operation_id, " ", callback=mocker.MagicMock() + ) + return op + + @pytest.fixture + def mock_timer(self, mocker): + return mocker.patch( + "azure.iot.device.provisioning.pipeline.pipeline_stages_provisioning.Timer" + ) + + @pytest.mark.it( + "Adds a provisioning timeout timer with the interval specified in the configuration to the operation, and starts it" + ) + def test_adds_timer(self, mocker, stage, op, mock_timer): + stage.run_op(op) + + assert mock_timer.call_count == 1 + assert mock_timer.call_args == mocker.call(constant.DEFAULT_TIMEOUT_INTERVAL, mocker.ANY) + assert op.provisioning_timeout_timer is mock_timer.return_value + assert op.provisioning_timeout_timer.start.call_count == 1 + assert op.provisioning_timeout_timer.start.call_args == mocker.call() + + @pytest.mark.it( + "Sends converted RequestResponse Op down the pipeline after attaching timer to the original op" + ) + def test_sends_down(self, mocker, stage, op, mock_timer): + stage.run_op(op) + + assert stage.send_op_down.call_count == 1 + + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_base.RequestAndResponseOperation) + + assert op.provisioning_timeout_timer is mock_timer.return_value + + @pytest.mark.it("Completes the operation unsuccessfully, with a ServiceError due to timeout") + def test_not_complete_timeout(self, mocker, stage, op, mock_timer): + # Apply the timer + stage.run_op(op) + assert not op.completed + assert mock_timer.call_count == 1 + on_timer_complete = mock_timer.call_args[0][1] + + # Call timer complete callback (indicating timer completion) + on_timer_complete() + + # Op is now completed with error + assert op.completed + assert isinstance(op.error, exceptions.ServiceError) + assert "query" in op.error.args[0] + + @pytest.mark.it( + "Completes the operation successfully, cancels and clears the operation's timeout timer" + ) + def test_complete_before_timeout(self, mocker, stage, op, mock_timer): + # Apply the timer + stage.run_op(op) + assert not op.completed + assert mock_timer.call_count == 1 + mock_timer_inst = op.provisioning_timeout_timer + assert mock_timer_inst is mock_timer.return_value + assert mock_timer_inst.cancel.call_count == 0 + + # Complete the next operation + new_op = stage.send_op_down.call_args[0][0] + new_op.complete() + + # Timer is now cancelled and cleared + assert mock_timer_inst.cancel.call_count == 1 + assert mock_timer_inst.cancel.call_args == mocker.call() + assert op.provisioning_timeout_timer is None diff --git a/azure-iot-device/tests/provisioning/pipeline/test_pipeline_stages_provisioning_mqtt.py b/azure-iot-device/tests/provisioning/pipeline/test_pipeline_stages_provisioning_mqtt.py index 6bb657934..be201e686 100644 --- a/azure-iot-device/tests/provisioning/pipeline/test_pipeline_stages_provisioning_mqtt.py +++ b/azure-iot-device/tests/provisioning/pipeline/test_pipeline_stages_provisioning_mqtt.py @@ -13,9 +13,9 @@ from azure.iot.device.common.pipeline import ( pipeline_stages_base, pipeline_ops_mqtt, pipeline_events_mqtt, + pipeline_events_base, ) from azure.iot.device.provisioning.pipeline import ( - pipeline_events_provisioning, pipeline_ops_provisioning, pipeline_stages_provisioning_mqtt, ) @@ -25,11 +25,13 @@ from tests.common.pipeline.helpers import ( all_common_ops, all_common_events, all_except, - make_mock_stage, - UnhandledException, + StageTestBase, ) -from tests.provisioning.pipeline.helpers import all_provisioning_ops, all_provisioning_events +from tests.provisioning.pipeline.helpers import all_provisioning_ops from tests.common.pipeline import pipeline_stage_test +import json +from azure.iot.device.provisioning.pipeline import constant as pipeline_constant +from azure.iot.device.product_info import ProductInfo logging.basicConfig(level=logging.DEBUG) @@ -64,158 +66,157 @@ fake_response_topic = "$dps/registrations/res/200/?$rid={}".format(fake_request_ ops_handled_by_this_stage = [ pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation, - pipeline_ops_provisioning.SendRegistrationRequestOperation, - pipeline_ops_provisioning.SendQueryRequestOperation, + pipeline_ops_base.RequestOperation, pipeline_ops_base.EnableFeatureOperation, pipeline_ops_base.DisableFeatureOperation, ] events_handled_by_this_stage = [pipeline_events_mqtt.IncomingMQTTMessageEvent] -pipeline_stage_test.add_base_pipeline_stage_tests( - cls=pipeline_stages_provisioning_mqtt.ProvisioningMQTTConverterStage, +pipeline_stage_test.add_base_pipeline_stage_tests_old( + cls=pipeline_stages_provisioning_mqtt.ProvisioningMQTTTranslationStage, module=this_module, all_ops=all_common_ops + all_provisioning_ops, handled_ops=ops_handled_by_this_stage, - all_events=all_common_events + all_provisioning_events, + all_events=all_common_events, handled_events=events_handled_by_this_stage, extra_initializer_defaults={"action_to_topic": dict}, ) -@pytest.fixture(scope="function") -def some_exception(): - return Exception("Alohomora") - - @pytest.fixture -def mock_stage(mocker): - return make_mock_stage(mocker, pipeline_stages_provisioning_mqtt.ProvisioningMQTTConverterStage) - - -@pytest.fixture -def set_security_client_args(callback): +def set_security_client_args(mocker): op = pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation( provisioning_host=fake_provisioning_host, registration_id=fake_registration_id, id_scope=fake_id_scope, sas_token=fake_sas_token, client_cert=fake_client_cert, - callback=callback, + callback=mocker.MagicMock(), ) + mocker.spy(op, "complete") return op -@pytest.fixture -def stages_configured(mock_stage, set_security_client_args, mocker): - set_security_client_args.callback = None - mock_stage.run_op(set_security_client_args) - mocker.resetall() +class ProvisioningMQTTTranslationStageTestBase(StageTestBase): + @pytest.fixture + def stage(self): + return pipeline_stages_provisioning_mqtt.ProvisioningMQTTTranslationStage() + + @pytest.fixture + def stages_configured(self, stage, stage_base_configuration, set_security_client_args, mocker): + mocker.spy(stage.pipeline_root, "handle_pipeline_event") + + stage.run_op(set_security_client_args) + mocker.resetall() @pytest.mark.describe( - "ProvisioningMQTTConverterStage run_op function with SetProvisioningClientConnectionArgsOperation" + "ProvisioningMQTTTranslationStage run_op function with SetProvisioningClientConnectionArgsOperation" ) -class TestProvisioningMQTTConverterWithSetProvisioningClientConnectionArgsOperation(object): +class TestProvisioningMQTTTranslationStageWithSetProvisioningClientConnectionArgsOperation( + ProvisioningMQTTTranslationStageTestBase +): @pytest.mark.it( "Runs a pipeline_ops_mqtt.SetMQTTConnectionArgsOperation operation on the next stage" ) - def test_runs_set_connection_args(self, mock_stage, set_security_client_args): - mock_stage.run_op(set_security_client_args) - assert mock_stage.next._execute_op.call_count == 1 - new_op = mock_stage.next._execute_op.call_args[0][0] + def test_runs_set_connection_args(self, stage, set_security_client_args): + stage.run_op(set_security_client_args) + assert stage.next._run_op.call_count == 1 + new_op = stage.next._run_op.call_args[0][0] assert isinstance(new_op, pipeline_ops_mqtt.SetMQTTConnectionArgsOperation) @pytest.mark.it( "Sets SetMQTTConnectionArgsOperation.client_id = SetProvisioningClientConnectionArgsOperation.registration_id" ) - def test_sets_client_id(self, mock_stage, set_security_client_args): - mock_stage.run_op(set_security_client_args) - new_op = mock_stage.next._execute_op.call_args[0][0] + def test_sets_client_id(self, stage, set_security_client_args): + stage.run_op(set_security_client_args) + new_op = stage.next._run_op.call_args[0][0] assert new_op.client_id == fake_registration_id @pytest.mark.it( "Sets SetMQTTConnectionArgsOperation.hostname = SetProvisioningClientConnectionArgsOperation.provisioning_host" ) - def test_sets_hostname(self, mock_stage, set_security_client_args): - mock_stage.run_op(set_security_client_args) - new_op = mock_stage.next._execute_op.call_args[0][0] + def test_sets_hostname(self, stage, set_security_client_args): + stage.run_op(set_security_client_args) + new_op = stage.next._run_op.call_args[0][0] assert new_op.hostname == fake_provisioning_host @pytest.mark.it( "Sets SetMQTTConnectionArgsOperation.client_cert = SetProvisioningClientConnectionArgsOperation.client_cert" ) - def test_sets_client_cert(self, mock_stage, set_security_client_args): - mock_stage.run_op(set_security_client_args) - new_op = mock_stage.next._execute_op.call_args[0][0] + def test_sets_client_cert(self, stage, set_security_client_args): + stage.run_op(set_security_client_args) + new_op = stage.next._run_op.call_args[0][0] assert new_op.client_cert == fake_client_cert @pytest.mark.it( "Sets SetMQTTConnectionArgsOperation.sas_token = SetProvisioningClientConnectionArgsOperation.sas_token" ) - def test_sets_sas_token(self, mock_stage, set_security_client_args): - mock_stage.run_op(set_security_client_args) - new_op = mock_stage.next._execute_op.call_args[0][0] + def test_sets_sas_token(self, stage, set_security_client_args): + stage.run_op(set_security_client_args) + new_op = stage.next._run_op.call_args[0][0] assert new_op.sas_token == fake_sas_token @pytest.mark.it( "Sets MqttConnectionArgsOperation.username = SetProvisioningClientConnectionArgsOperation.{id_scope}/registrations/{registration_id}/api-version={api_version}&ClientVersion={client_version}" ) - def test_sets_username(self, mock_stage, set_security_client_args): - mock_stage.run_op(set_security_client_args) - new_op = mock_stage.next._execute_op.call_args[0][0] + def test_sets_username(self, stage, set_security_client_args): + stage.run_op(set_security_client_args) + new_op = stage.next._run_op.call_args[0][0] assert ( new_op.username == "{id_scope}/registrations/{registration_id}/api-version={api_version}&ClientVersion={client_version}".format( id_scope=fake_id_scope, registration_id=fake_registration_id, api_version=constant.PROVISIONING_API_VERSION, - client_version=urllib.parse.quote_plus(constant.USER_AGENT), + client_version=urllib.parse.quote_plus(ProductInfo.get_provisioning_user_agent()), ) ) @pytest.mark.it( - "Calls the SetSymmetricKeySecurityClientArgs callback with error if the pipeline_ops_mqtt.SetMQTTConnectionArgsOperation operation raises an Exception" + "Completes the SetSymmetricKeySecurityClientArgs op with error if the pipeline_ops_mqtt.SetMQTTConnectionArgsOperation operation raises an Exception" ) def test_set_connection_args_raises_exception( - self, mock_stage, mocker, some_exception, set_security_client_args + self, stage, mocker, arbitrary_exception, set_security_client_args ): - mock_stage.next._execute_op = mocker.Mock(side_effect=some_exception) - mock_stage.run_op(set_security_client_args) - assert_callback_failed(op=set_security_client_args, error=some_exception) - - @pytest.mark.it( - "Allows any BaseExceptions raised inside the pipeline_ops_mqtt.SetMQTTConnectionArgsOperation operation to propagate" - ) - def test_set_connection_args_raises_base_exception( - self, mock_stage, mocker, fake_base_exception, set_security_client_args - ): - mock_stage.next._execute_op = mocker.Mock(side_effect=fake_base_exception) - with pytest.raises(UnhandledException): - mock_stage.run_op(set_security_client_args) + stage.next._run_op = mocker.Mock(side_effect=arbitrary_exception) + stage.run_op(set_security_client_args) + assert set_security_client_args.complete.call_count == 1 + assert set_security_client_args.complete.call_args == mocker.call(error=arbitrary_exception) @pytest.mark.it( "Calls the SetSymmetricKeySecurityClientArgs callback with no error if the pipeline_ops_mqtt.SetMQTTConnectionArgsOperation operation succeeds" ) def test_returns_success_if_set_connection_args_succeeds( - self, mock_stage, mocker, set_security_client_args + self, stage, mocker, set_security_client_args, next_stage_succeeds ): - mock_stage.run_op(set_security_client_args) - assert_callback_succeeded(op=set_security_client_args) + stage.run_op(set_security_client_args) + assert set_security_client_args.complete.call_count == 1 + assert set_security_client_args.complete.call_args == mocker.call(error=None) basic_ops = [ { - "op_class": pipeline_ops_provisioning.SendRegistrationRequestOperation, - "op_init_kwargs": {"request_id": fake_request_id, "request_payload": fake_mqtt_payload}, + "op_class": pipeline_ops_base.RequestOperation, + "op_init_kwargs": { + "request_id": fake_request_id, + "request_type": pipeline_constant.REGISTER, + "method": "PUT", + "resource_location": "/", + "request_body": "test payload", + }, "new_op_class": pipeline_ops_mqtt.MQTTPublishOperation, }, { - "op_class": pipeline_ops_provisioning.SendQueryRequestOperation, + "op_class": pipeline_ops_base.RequestOperation, "op_init_kwargs": { "request_id": fake_request_id, - "operation_id": fake_operation_id, - "request_payload": fake_mqtt_payload, + "request_type": pipeline_constant.QUERY, + "method": "GET", + "resource_location": "/", + "query_params": {"operation_id": fake_operation_id}, + "request_body": "test payload", }, "new_op_class": pipeline_ops_mqtt.MQTTPublishOperation, }, @@ -233,9 +234,9 @@ basic_ops = [ @pytest.fixture -def op(params, callback): - op = params["op_class"](**params["op_init_kwargs"]) - op.callback = callback +def op(params, mocker): + op = params["op_class"](callback=mocker.MagicMock(), **params["op_init_kwargs"]) + mocker.spy(op, "complete") return op @@ -244,53 +245,90 @@ def op(params, callback): basic_ops, ids=["{}->{}".format(x["op_class"].__name__, x["new_op_class"].__name__) for x in basic_ops], ) -@pytest.mark.describe("ProvisioningMQTTConverterStage basic operation tests") -class TestProvisioningMQTTConverterBasicOperations(object): +@pytest.mark.describe("ProvisioningMQTTTranslationStage basic operation tests") +class TestProvisioningMQTTTranslationStageBasicOperations(ProvisioningMQTTTranslationStageTestBase): @pytest.mark.it("Runs an operation on the next stage") - def test_runs_publish(self, params, mock_stage, stages_configured, op): - mock_stage.run_op(op) - new_op = mock_stage.next._execute_op.call_args[0][0] + def test_runs_publish(self, params, stage, stages_configured, op): + stage.run_op(op) + new_op = stage.next._run_op.call_args[0][0] assert isinstance(new_op, params["new_op_class"]) - @pytest.mark.it("Calls the original op callback with error if the new_op raises an Exception") + @pytest.mark.it("Completes the original op with error if the new_op raises an Exception") def test_new_op_raises_exception( - self, params, mocker, mock_stage, stages_configured, op, some_exception + self, params, mocker, stage, stages_configured, op, arbitrary_exception ): - mock_stage.next._execute_op = mocker.Mock(side_effect=some_exception) - mock_stage.run_op(op) - assert_callback_failed(op=op, error=some_exception) + stage.next._run_op = mocker.Mock(side_effect=arbitrary_exception) + stage.run_op(op) + assert op.complete.call_count == 1 + assert op.complete.call_args == mocker.call(error=arbitrary_exception) @pytest.mark.it("Allows any BaseExceptions raised from inside new_op to propagate") def test_new_op_raises_base_exception( - self, params, mocker, mock_stage, stages_configured, op, fake_base_exception + self, params, mocker, stage, stages_configured, op, arbitrary_base_exception ): - mock_stage.next._execute_op = mocker.Mock(side_effect=fake_base_exception) - with pytest.raises(UnhandledException): - mock_stage.run_op(op) + stage.next._run_op = mocker.Mock(side_effect=arbitrary_base_exception) + with pytest.raises(arbitrary_base_exception.__class__) as e_info: + stage.run_op(op) + e_info.value is arbitrary_base_exception - @pytest.mark.it("Calls the original op callback with no error if the new_op operation succeeds") - def test_returns_success_if_publish_succeeds(self, params, mock_stage, stages_configured, op): - mock_stage.run_op(op) - assert_callback_succeeded(op) + @pytest.mark.it("Completes the original op with no error if the new_op operation succeeds") + def test_returns_success_if_publish_succeeds( + self, mocker, params, stage, stages_configured, op, next_stage_succeeds + ): + stage.run_op(op) + assert op.complete.call_count == 1 + assert op.complete.call_args == mocker.call(error=None) publish_ops = [ { - "name": "send register request", - "op_class": pipeline_ops_provisioning.SendRegistrationRequestOperation, - "op_init_kwargs": {"request_id": fake_request_id, "request_payload": fake_mqtt_payload}, + "name": "send register request with no payload", + "op_class": pipeline_ops_base.RequestOperation, + "op_init_kwargs": { + "request_id": fake_request_id, + "request_type": pipeline_constant.REGISTER, + "method": "PUT", + "resource_location": "/", + "request_body": '{{"payload": {json_payload}, "registrationId": "{reg_id}"}}'.format( + reg_id=fake_registration_id, json_payload=json.dumps(None) + ), + }, "topic": "$dps/registrations/PUT/iotdps-register/?$rid={request_id}".format( request_id=fake_request_id ), - "publish_payload": fake_mqtt_payload, + "publish_payload": '{{"payload": {json_payload}, "registrationId": "{reg_id}"}}'.format( + reg_id=fake_registration_id, json_payload=json.dumps(None) + ), + }, + { + "name": "send register request with payload", + "op_class": pipeline_ops_base.RequestOperation, + "op_init_kwargs": { + "request_id": fake_request_id, + "request_type": pipeline_constant.REGISTER, + "method": "PUT", + "resource_location": "/", + "request_body": '{{"payload": {json_payload}, "registrationId": "{reg_id}"}}'.format( + reg_id=fake_registration_id, json_payload=json.dumps(fake_mqtt_payload) + ), + }, + "topic": "$dps/registrations/PUT/iotdps-register/?$rid={request_id}".format( + request_id=fake_request_id + ), + "publish_payload": '{{"payload": {json_payload}, "registrationId": "{reg_id}"}}'.format( + reg_id=fake_registration_id, json_payload=json.dumps(fake_mqtt_payload) + ), }, { "name": "send query request", - "op_class": pipeline_ops_provisioning.SendQueryRequestOperation, + "op_class": pipeline_ops_base.RequestOperation, "op_init_kwargs": { "request_id": fake_request_id, - "operation_id": fake_operation_id, - "request_payload": fake_mqtt_payload, + "query_params": {"operation_id": fake_operation_id}, + "request_type": pipeline_constant.QUERY, + "method": "GET", + "resource_location": "/", + "request_body": fake_mqtt_payload, }, "topic": "$dps/registrations/GET/iotdps-get-operationstatus/?$rid={request_id}&operationId={operation_id}".format( request_id=fake_request_id, operation_id=fake_operation_id @@ -301,18 +339,18 @@ publish_ops = [ @pytest.mark.parametrize("params", publish_ops, ids=[x["name"] for x in publish_ops]) -@pytest.mark.describe("ProvisioningMQTTConverterStage run_op function for publish operations") -class TestProvisioningMQTTConverterForPublishOps(object): +@pytest.mark.describe("ProvisioningMQTTTranslationStage run_op function for publish operations") +class TestProvisioningMQTTTranslationStageForPublishOps(ProvisioningMQTTTranslationStageTestBase): @pytest.mark.it("Uses correct registration topic string when publishing") - def test_uses_topic_for(self, mock_stage, stages_configured, params, op): - mock_stage.run_op(op) - new_op = mock_stage.next._execute_op.call_args[0][0] + def test_uses_topic_for(self, stage, stages_configured, params, op): + stage.run_op(op) + new_op = stage.next._run_op.call_args[0][0] assert new_op.topic == params["topic"] @pytest.mark.it("Sends correct payload when publishing") - def test_sends_correct_body(self, mock_stage, stages_configured, params, op): - mock_stage.run_op(op) - new_op = mock_stage.next._execute_op.call_args[0][0] + def test_sends_correct_body(self, stage, stages_configured, params, op): + stage.run_op(op) + new_op = stage.next._run_op.call_args[0][0] assert new_op.payload == params["publish_payload"] @@ -328,85 +366,78 @@ sub_unsub_operations = [ ] -@pytest.mark.describe("ProvisioningMQTTConverterStage run_op function with EnableFeature operation") -class TestProvisioningMQTTConverterWithEnable(object): +@pytest.mark.describe( + "ProvisioningMQTTTranslationStage run_op function with EnableFeature operation" +) +class TestProvisioningMQTTTranslationStageWithEnable(ProvisioningMQTTTranslationStageTestBase): @pytest.mark.parametrize( "op_parameters", sub_unsub_operations, ids=[x["op_class"].__name__ for x in sub_unsub_operations], ) @pytest.mark.it("Gets the correct topic") - def test_converts_feature_name_to_topic( - self, mocker, mock_stage, stages_configured, op_parameters - ): + def test_converts_feature_name_to_topic(self, mocker, stage, stages_configured, op_parameters): topic = "$dps/registrations/res/#" - mock_stage.next._execute_op = mocker.Mock() + stage.next._run_op = mocker.Mock() - op = op_parameters["op_class"](feature_name=None) - mock_stage.run_op(op) - new_op = mock_stage.next._execute_op.call_args[0][0] + op = op_parameters["op_class"](feature_name=None, callback=mocker.MagicMock()) + stage.run_op(op) + new_op = stage.next._run_op.call_args[0][0] assert isinstance(new_op, op_parameters["new_op"]) assert new_op.topic == topic -@pytest.fixture -def add_pipeline_root(mock_stage, mocker): - root = pipeline_stages_base.PipelineRootStage() - mocker.spy(root, "handle_pipeline_event") - mock_stage.previous = root - - -@pytest.mark.describe("ProvisioningMQTTConverterStage _handle_pipeline_event") -class TestProvisioningMQTTConverterHandlePipelineEvent(object): +@pytest.mark.describe("ProvisioningMQTTTranslationStage _handle_pipeline_event") +class TestProvisioningMQTTTranslationStageHandlePipelineEvent( + ProvisioningMQTTTranslationStageTestBase +): @pytest.mark.it("Passes up any mqtt messages with topics that aren't matched by this stage") - def test_passes_up_mqtt_message_with_unknown_topic( - self, mock_stage, stages_configured, add_pipeline_root, mocker - ): + def test_passes_up_mqtt_message_with_unknown_topic(self, stage, stages_configured, mocker): event = pipeline_events_mqtt.IncomingMQTTMessageEvent( topic=unmatched_mqtt_topic, payload=fake_mqtt_payload ) - mock_stage.handle_pipeline_event(event) - assert mock_stage.previous.handle_pipeline_event.call_count == 1 - assert mock_stage.previous.handle_pipeline_event.call_args == mocker.call(event) + stage.handle_pipeline_event(event) + assert stage.previous.handle_pipeline_event.call_count == 1 + assert stage.previous.handle_pipeline_event.call_args == mocker.call(event) @pytest.fixture def dps_response_event(): return pipeline_events_mqtt.IncomingMQTTMessageEvent( - topic=fake_response_topic, payload=fake_mqtt_payload.encode("utf-8") + topic=fake_response_topic, payload=fake_mqtt_payload ) -@pytest.mark.describe("ProvisioningMQTTConverterStage _handle_pipeline_event for response") -class TestProvisioningMQTTConverterHandlePipelineEventRegistrationResponse(object): +@pytest.mark.describe("ProvisioningMQTTTranslationStage _handle_pipeline_event for response") +class TestProvisioningMQTTConverterHandlePipelineEventRegistrationResponse( + ProvisioningMQTTTranslationStageTestBase +): @pytest.mark.it( "Converts mqtt message with topic $dps/registrations/res/#/ to registration response event" ) def test_converts_response_topic_to_registration_response_event( - self, mocker, mock_stage, stages_configured, add_pipeline_root, dps_response_event + self, mocker, stage, stages_configured, dps_response_event ): - mock_stage.handle_pipeline_event(dps_response_event) - assert mock_stage.previous.handle_pipeline_event.call_count == 1 - new_event = mock_stage.previous.handle_pipeline_event.call_args[0][0] - assert isinstance(new_event, pipeline_events_provisioning.RegistrationResponseEvent) + stage.handle_pipeline_event(dps_response_event) + assert stage.previous.handle_pipeline_event.call_count == 1 + new_event = stage.previous.handle_pipeline_event.call_args[0][0] + assert isinstance(new_event, pipeline_events_base.ResponseEvent) @pytest.mark.it("Extracts message properties from the mqtt topic for c2d messages") def test_extracts_some_properties_from_topic( - self, mocker, mock_stage, stages_configured, add_pipeline_root, dps_response_event + self, mocker, stage, stages_configured, dps_response_event ): - mock_stage.handle_pipeline_event(dps_response_event) - new_event = mock_stage.previous.handle_pipeline_event.call_args[0][0] + stage.handle_pipeline_event(dps_response_event) + new_event = stage.previous.handle_pipeline_event.call_args[0][0] assert new_event.request_id == fake_request_id - assert new_event.status_code == "200" + assert new_event.status_code == 200 @pytest.mark.it("Passes up other messages") - def test_if_topic_is_not_response( - self, mocker, mock_stage, stages_configured, add_pipeline_root - ): + def test_if_topic_is_not_response(self, mocker, stage, stages_configured): fake_some_other_topic = "devices/{}/messages/devicebound/".format(fake_device_id) event = pipeline_events_mqtt.IncomingMQTTMessageEvent( topic=fake_some_other_topic, payload=fake_mqtt_payload ) - mock_stage.handle_pipeline_event(event) - assert mock_stage.previous.handle_pipeline_event.call_count == 1 - assert mock_stage.previous.handle_pipeline_event.call_args == mocker.call(event) + stage.handle_pipeline_event(event) + assert stage.previous.handle_pipeline_event.call_count == 1 + assert stage.previous.handle_pipeline_event.call_args == mocker.call(event) diff --git a/azure-iot-device/tests/provisioning/pipeline/test_provisioning_pipeline.py b/azure-iot-device/tests/provisioning/pipeline/test_provisioning_pipeline.py index d6c1a49f7..13ee36ccb 100644 --- a/azure-iot-device/tests/provisioning/pipeline/test_provisioning_pipeline.py +++ b/azure-iot-device/tests/provisioning/pipeline/test_provisioning_pipeline.py @@ -7,16 +7,29 @@ import pytest import logging from azure.iot.device.common.models import X509 -from azure.iot.device.common.pipeline import pipeline_stages_base, operation_flow from azure.iot.device.provisioning.security.sk_security_client import SymmetricKeySecurityClient from azure.iot.device.provisioning.security.x509_security_client import X509SecurityClient from azure.iot.device.provisioning.pipeline.provisioning_pipeline import ProvisioningPipeline -from azure.iot.device.provisioning.pipeline import pipeline_ops_provisioning from tests.common.pipeline import helpers +import json +from azure.iot.device.provisioning.pipeline import constant as dps_constants +from azure.iot.device.provisioning.pipeline import ( + pipeline_stages_provisioning, + pipeline_stages_provisioning_mqtt, + pipeline_ops_provisioning, +) +from azure.iot.device.common.pipeline import ( + pipeline_stages_base, + pipeline_stages_mqtt, + pipeline_ops_base, +) logging.basicConfig(level=logging.DEBUG) +pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") + + +feature = dps_constants.REGISTER -send_msg_qos = 1 fake_symmetric_key = "Zm9vYmFy" fake_registration_id = "MyPensieve" @@ -31,6 +44,9 @@ fake_sas_token = "horcrux_token" fake_security_client = "secure_via_muffliato" fake_request_id = "fake_request_1234" fake_mqtt_payload = "hello hogwarts" +fake_register_publish_payload = '{{"payload": {json_payload}, "registrationId": "{reg_id}"}}'.format( + reg_id=fake_registration_id, json_payload=json.dumps(fake_mqtt_payload) +) fake_operation_id = "fake_operation_9876" fake_sub_unsub_topic = "$dps/registrations/res/#" fake_x509_cert_file = "fantastic_beasts" @@ -92,6 +108,18 @@ def input_security_client(params_security_clients): return params_security_clients["client_class"](**params_security_clients["init_kwargs"]) +@pytest.fixture +def pipeline_configuration(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def pipeline(mocker, input_security_client, pipeline_configuration): + pipeline = ProvisioningPipeline(input_security_client, pipeline_configuration) + mocker.patch.object(pipeline._pipeline, "run_op") + return pipeline + + # automatically mock the transport for all tests in this file. @pytest.fixture(autouse=True) def mock_mqtt_transport(mocker): @@ -101,8 +129,10 @@ def mock_mqtt_transport(mocker): @pytest.fixture(scope="function") -def mock_provisioning_pipeline(mocker, input_security_client, mock_mqtt_transport): - provisioning_pipeline = ProvisioningPipeline(input_security_client) +def mock_provisioning_pipeline( + mocker, input_security_client, mock_mqtt_transport, pipeline_configuration +): + provisioning_pipeline = ProvisioningPipeline(input_security_client, pipeline_configuration) provisioning_pipeline.on_connected = mocker.MagicMock() provisioning_pipeline.on_disconnected = mocker.MagicMock() provisioning_pipeline.on_message_received = mocker.MagicMock() @@ -114,123 +144,193 @@ def mock_provisioning_pipeline(mocker, input_security_client, mock_mqtt_transpor provisioning_pipeline.disconnect() +@pytest.mark.describe("ProvisioningPipeline - Instantiation") @pytest.mark.parametrize("params_security_clients", different_security_clients) -@pytest.mark.describe("Provisioning pipeline - Initializer") -class TestInit(object): - @pytest.mark.it("Happens correctly with the specific security client") - def test_instantiates_correctly(self, params_security_clients, input_security_client): - provisioning_pipeline = ProvisioningPipeline(input_security_client) - assert provisioning_pipeline._pipeline is not None +class TestProvisioningPipelineInstantiation(object): + @pytest.mark.it("Begins tracking the enabled/disabled status of responses") + def test_features(self, input_security_client, pipeline_configuration): + pipeline = ProvisioningPipeline(input_security_client, pipeline_configuration) + pipeline.responses_enabled[feature] + # No assertion required - if this doesn't raise a KeyError, it is a success - @pytest.mark.it("Calls the correct op to pass the security client args into the pipeline") - def test_passes_security_client_args( - self, mocker, params_security_clients, input_security_client + @pytest.mark.it("Marks responses as disabled") + def test_features_disabled(self, input_security_client, pipeline_configuration): + pipeline = ProvisioningPipeline(input_security_client, pipeline_configuration) + assert not pipeline.responses_enabled[feature] + + @pytest.mark.it("Sets all handlers to an initial value of None") + def test_handlers_set_to_none(self, input_security_client, pipeline_configuration): + pipeline = ProvisioningPipeline(input_security_client, pipeline_configuration) + assert pipeline.on_connected is None + assert pipeline.on_disconnected is None + assert pipeline.on_message_received is None + + @pytest.mark.it("Configures the pipeline to trigger handlers in response to external events") + def test_handlers_configured(self, input_security_client, pipeline_configuration): + pipeline = ProvisioningPipeline(input_security_client, pipeline_configuration) + assert pipeline._pipeline.on_pipeline_event_handler is not None + assert pipeline._pipeline.on_connected_handler is not None + assert pipeline._pipeline.on_disconnected_handler is not None + + @pytest.mark.it("Configures the pipeline with a series of PipelineStages") + def test_pipeline_configuration(self, input_security_client, pipeline_configuration): + pipeline = ProvisioningPipeline(input_security_client, pipeline_configuration) + curr_stage = pipeline._pipeline + + expected_stage_order = [ + pipeline_stages_base.PipelineRootStage, + pipeline_stages_provisioning.UseSecurityClientStage, + pipeline_stages_provisioning.RegistrationStage, + pipeline_stages_provisioning.PollingStatusStage, + pipeline_stages_base.CoordinateRequestAndResponseStage, + pipeline_stages_provisioning_mqtt.ProvisioningMQTTTranslationStage, + pipeline_stages_base.AutoConnectStage, + pipeline_stages_base.ReconnectStage, + pipeline_stages_base.ConnectionLockStage, + pipeline_stages_base.RetryStage, + pipeline_stages_base.OpTimeoutStage, + pipeline_stages_mqtt.MQTTTransportStage, + ] + + # Assert that all PipelineStages are there, and they are in the right order + for i in range(len(expected_stage_order)): + expected_stage = expected_stage_order[i] + assert isinstance(curr_stage, expected_stage) + curr_stage = curr_stage.next + + # Assert there are no more additional stages + assert curr_stage is None + + # TODO: revist these tests after auth revision + # They are too tied to auth types (and there's too much variance in auths to effectively test) + # Ideally ProvisioningPipeline is entirely insulated from any auth differential logic (and module/device distinctions) + # In the meantime, we are using a device auth with connection string to stand in for generic SAS auth + # and device auth with X509 certs to stand in for generic X509 auth + @pytest.mark.it( + "Runs a Set SecurityClient Operation with the provided SecurityClient on the pipeline" + ) + def test_security_client_success( + self, mocker, params_security_clients, input_security_client, pipeline_configuration ): mocker.spy(pipeline_stages_base.PipelineRootStage, "run_op") - provisioning_pipeline = ProvisioningPipeline(input_security_client) + pipeline = ProvisioningPipeline(input_security_client, pipeline_configuration) - op = provisioning_pipeline._pipeline.run_op.call_args[0][1] - assert provisioning_pipeline._pipeline.run_op.call_count == 1 + op = pipeline._pipeline.run_op.call_args[0][1] + assert pipeline._pipeline.run_op.call_count == 1 assert isinstance(op, params_security_clients["set_args_op_class"]) assert op.security_client is input_security_client - @pytest.mark.it("Raises an exception if the pipeline op to set security client args fails") - def test_passes_security_client_args_failure( - self, mocker, params_security_clients, input_security_client, fake_exception + @pytest.mark.it( + "Raises exceptions that occurred in execution upon unsuccessful completion of the Set SecurityClient Operation" + ) + def test_security_client_failure( + self, + mocker, + params_security_clients, + input_security_client, + arbitrary_exception, + pipeline_configuration, ): - old_execute_op = pipeline_stages_base.PipelineRootStage._execute_op + old_run_op = pipeline_stages_base.PipelineRootStage._run_op - def fail_set_auth_provider(self, op): + def fail_set_security_client(self, op): if isinstance(op, params_security_clients["set_args_op_class"]): - op.error = fake_exception - operation_flow.complete_op(stage=self, op=op) + op.complete(error=arbitrary_exception) else: - old_execute_op(self, op) + old_run_op(self, op) mocker.patch.object( pipeline_stages_base.PipelineRootStage, - "_execute_op", - side_effect=fail_set_auth_provider, + "_run_op", + side_effect=fail_set_security_client, + autospec=True, ) - with pytest.raises(fake_exception.__class__): - ProvisioningPipeline(input_security_client) + # auth_provider = SymmetricKeyAuthenticationProvider.parse(device_connection_string) + with pytest.raises(arbitrary_exception.__class__) as e_info: + ProvisioningPipeline(input_security_client, pipeline_configuration) + assert e_info.value is arbitrary_exception @pytest.mark.parametrize("params_security_clients", different_security_clients) @pytest.mark.describe("Provisioning pipeline - Connect") -class TestConnect(object): - @pytest.mark.it("Calls connect on transport") - def test_connect_calls_connect_on_provider( - self, params_security_clients, mock_provisioning_pipeline, mock_mqtt_transport - ): - mock_provisioning_pipeline.connect() - - assert mock_mqtt_transport.connect.call_count == 1 - - if params_security_clients["client_class"].__name__ == "SymmetricKeySecurityClient": - assert mock_mqtt_transport.connect.call_args[1]["password"] is not None - assert_for_symmetric_key(mock_mqtt_transport.connect.call_args[1]["password"]) - elif params_security_clients["client_class"].__name__ == "X509SecurityClient": - assert mock_mqtt_transport.connect.call_args[1]["password"] is None - - mock_mqtt_transport.on_mqtt_connected_handler() - mock_provisioning_pipeline.wait_for_on_connected_to_be_called() - - @pytest.mark.it("After complete calls handler with new state") - def test_connected_state_handler_called_wth_new_state_once_provider_gets_connected( - self, mock_provisioning_pipeline, mock_mqtt_transport - ): - mock_provisioning_pipeline.connect() - mock_mqtt_transport.on_mqtt_connected_handler() - mock_provisioning_pipeline.wait_for_on_connected_to_be_called() - - mock_provisioning_pipeline.on_connected.assert_called_once_with("connected") - - @pytest.mark.it("Is ignored if waiting for completion of previous one") - def test_connect_ignored_if_waiting_for_connect_complete( - self, mock_provisioning_pipeline, params_security_clients, mock_mqtt_transport - ): - mock_provisioning_pipeline.connect() - mock_provisioning_pipeline.connect() - mock_mqtt_transport.on_mqtt_connected_handler() - mock_provisioning_pipeline.wait_for_on_connected_to_be_called() - - assert mock_mqtt_transport.connect.call_count == 1 - - if params_security_clients["client_class"].__name__ == "SymmetricKeySecurityClient": - assert mock_mqtt_transport.connect.call_args[1]["password"] is not None - assert_for_symmetric_key(mock_mqtt_transport.connect.call_args[1]["password"]) - elif params_security_clients["client_class"].__name__ == "X509SecurityClient": - assert mock_mqtt_transport.connect.call_args[1]["password"] is None - - mock_provisioning_pipeline.on_connected.assert_called_once_with("connected") - - @pytest.mark.it("Is ignored if waiting for completion of send") - def test_connect_ignored_if_waiting_for_send_complete( - self, mock_provisioning_pipeline, mock_mqtt_transport - ): - mock_provisioning_pipeline.connect() - mock_mqtt_transport.on_mqtt_connected_handler() - mock_provisioning_pipeline.wait_for_on_connected_to_be_called() - - mock_mqtt_transport.reset_mock() - mock_provisioning_pipeline.on_connected.reset_mock() - - mock_provisioning_pipeline.send_request( - request_id=fake_request_id, request_payload=fake_mqtt_payload +class TestProvisioningPipelineConnect(object): + @pytest.mark.it("Runs a ConnectOperation on the pipeline") + def test_runs_op(self, pipeline, mocker): + cb = mocker.MagicMock() + pipeline.connect(callback=cb) + assert pipeline._pipeline.run_op.call_count == 1 + assert isinstance( + pipeline._pipeline.run_op.call_args[0][0], pipeline_ops_base.ConnectOperation ) - mock_provisioning_pipeline.connect() - mock_mqtt_transport.connect.assert_not_called() - mock_provisioning_pipeline.wait_for_on_connected_to_not_be_called() - mock_provisioning_pipeline.on_connected.assert_not_called() + @pytest.mark.it("Triggers the callback upon successful completion of the ConnectOperation") + def test_op_success_with_callback(self, mocker, pipeline): + cb = mocker.MagicMock() - mock_mqtt_transport.on_mqtt_published(0) + # Begin operation + pipeline.connect(callback=cb) + assert cb.call_count == 0 - mock_mqtt_transport.connect.assert_not_called() - mock_provisioning_pipeline.wait_for_on_connected_to_not_be_called() - mock_provisioning_pipeline.on_connected.assert_not_called() + # Trigger op completion + op = pipeline._pipeline.run_op.call_args[0][0] + op.complete(error=None) + + assert cb.call_count == 1 + assert cb.call_args == mocker.call(error=None) + + @pytest.mark.it( + "Calls the callback with the error upon unsuccessful completion of the ConnectOperation" + ) + def test_op_fail(self, mocker, pipeline, arbitrary_exception): + cb = mocker.MagicMock() + + pipeline.connect(callback=cb) + op = pipeline._pipeline.run_op.call_args[0][0] + + op.complete(error=arbitrary_exception) + assert cb.call_count == 1 + assert cb.call_args == mocker.call(error=arbitrary_exception) + + +@pytest.mark.parametrize("params_security_clients", different_security_clients) +@pytest.mark.describe("IoTHubPipeline - .disconnect()") +class TestProvisioningPipelineDisconnect(object): + @pytest.mark.it("Runs a DisconnectOperation on the pipeline") + def test_runs_op(self, pipeline, mocker): + pipeline.disconnect(callback=mocker.MagicMock()) + assert pipeline._pipeline.run_op.call_count == 1 + assert isinstance( + pipeline._pipeline.run_op.call_args[0][0], pipeline_ops_base.DisconnectOperation + ) + + @pytest.mark.it("Triggers the callback upon successful completion of the DisconnectOperation") + def test_op_success_with_callback(self, mocker, pipeline): + cb = mocker.MagicMock() + + # Begin operation + pipeline.disconnect(callback=cb) + assert cb.call_count == 0 + + # Trigger op completion callback + op = pipeline._pipeline.run_op.call_args[0][0] + op.complete(error=None) + + assert cb.call_count == 1 + assert cb.call_args == mocker.call(error=None) + + @pytest.mark.it( + "Calls the callback with the error upon unsuccessful completion of the DisconnectOperation" + ) + def test_op_fail(self, mocker, pipeline, arbitrary_exception): + cb = mocker.MagicMock() + pipeline.disconnect(callback=cb) + + op = pipeline._pipeline.run_op.call_args[0][0] + op.complete(error=arbitrary_exception) + + assert cb.call_count == 1 + assert cb.call_args == mocker.call(error=arbitrary_exception) @pytest.mark.parametrize("params_security_clients", different_security_clients) @@ -238,14 +338,17 @@ class TestConnect(object): class TestSendRegister(object): @pytest.mark.it("Request calls publish on provider") def test_send_register_request_calls_publish_on_provider( - self, mock_provisioning_pipeline, params_security_clients, mock_mqtt_transport + self, mocker, mock_provisioning_pipeline, params_security_clients, mock_mqtt_transport ): + mock_init_uuid = mocker.patch( + "azure.iot.device.common.pipeline.pipeline_stages_base.uuid.uuid4" + ) + mock_init_uuid.return_value = fake_request_id + mock_provisioning_pipeline.connect() mock_mqtt_transport.on_mqtt_connected_handler() mock_provisioning_pipeline.wait_for_on_connected_to_be_called() - mock_provisioning_pipeline.send_request( - request_id=fake_request_id, request_payload=fake_mqtt_payload - ) + mock_provisioning_pipeline.register(payload=fake_mqtt_payload) assert mock_mqtt_transport.connect.call_count == 1 @@ -262,16 +365,19 @@ class TestSendRegister(object): mock_mqtt_transport.wait_for_publish_to_be_called() assert mock_mqtt_transport.publish.call_count == 1 assert mock_mqtt_transport.publish.call_args[1]["topic"] == fake_publish_topic - assert mock_mqtt_transport.publish.call_args[1]["payload"] == fake_mqtt_payload + assert mock_mqtt_transport.publish.call_args[1]["payload"] == fake_register_publish_payload @pytest.mark.it("Request queues and connects before calling publish on provider") def test_send_request_queues_and_connects_before_sending( - self, mock_provisioning_pipeline, params_security_clients, mock_mqtt_transport + self, mocker, mock_provisioning_pipeline, params_security_clients, mock_mqtt_transport ): - # send an event - mock_provisioning_pipeline.send_request( - request_id=fake_request_id, request_payload=fake_mqtt_payload + + mock_init_uuid = mocker.patch( + "azure.iot.device.common.pipeline.pipeline_stages_base.uuid.uuid4" ) + mock_init_uuid.return_value = fake_request_id + # send an event + mock_provisioning_pipeline.register(payload=fake_mqtt_payload) # verify that we called connect assert mock_mqtt_transport.connect.call_count == 1 @@ -302,12 +408,17 @@ class TestSendRegister(object): mock_mqtt_transport.wait_for_publish_to_be_called() assert mock_mqtt_transport.publish.call_count == 1 assert mock_mqtt_transport.publish.call_args[1]["topic"] == fake_publish_topic - assert mock_mqtt_transport.publish.call_args[1]["payload"] == fake_mqtt_payload + assert mock_mqtt_transport.publish.call_args[1]["payload"] == fake_register_publish_payload @pytest.mark.it("Request queues and waits for connect to be completed") def test_send_request_queues_if_waiting_for_connect_complete( - self, mock_provisioning_pipeline, params_security_clients, mock_mqtt_transport + self, mocker, mock_provisioning_pipeline, params_security_clients, mock_mqtt_transport ): + mock_init_uuid = mocker.patch( + "azure.iot.device.common.pipeline.pipeline_stages_base.uuid.uuid4" + ) + mock_init_uuid.return_value = fake_request_id + # start connecting and verify that we've called into the transport mock_provisioning_pipeline.connect() assert mock_mqtt_transport.connect.call_count == 1 @@ -319,9 +430,7 @@ class TestSendRegister(object): assert mock_mqtt_transport.connect.call_args[1]["password"] is None # send an event - mock_provisioning_pipeline.send_request( - request_id=fake_request_id, request_payload=fake_mqtt_payload - ) + mock_provisioning_pipeline.register(payload=fake_mqtt_payload) # verify that we're not connected yet and verify that we havent't published yet mock_provisioning_pipeline.wait_for_on_connected_to_not_be_called() @@ -341,15 +450,19 @@ class TestSendRegister(object): mock_mqtt_transport.wait_for_publish_to_be_called() assert mock_mqtt_transport.publish.call_count == 1 assert mock_mqtt_transport.publish.call_args[1]["topic"] == fake_publish_topic - assert mock_mqtt_transport.publish.call_args[1]["payload"] == fake_mqtt_payload + assert mock_mqtt_transport.publish.call_args[1]["payload"] == fake_register_publish_payload @pytest.mark.it("Request can be sent multiple times overlapping each other") def test_send_request_sends_overlapped_events( self, mock_provisioning_pipeline, mock_mqtt_transport, mocker ): + mock_init_uuid = mocker.patch( + "azure.iot.device.common.pipeline.pipeline_stages_base.uuid.uuid4" + ) + mock_init_uuid.return_value = fake_request_id + fake_request_id_1 = fake_request_id fake_msg_1 = fake_mqtt_payload - fake_request_id_2 = "request_4567" fake_msg_2 = "Petrificus Totalus" # connect @@ -359,9 +472,7 @@ class TestSendRegister(object): # send an event callback_1 = mocker.MagicMock() - mock_provisioning_pipeline.send_request( - request_id=fake_request_id_1, request_payload=fake_msg_1, callback=callback_1 - ) + mock_provisioning_pipeline.register(payload=fake_msg_1, callback=callback_1) fake_publish_topic = "$dps/registrations/PUT/iotdps-register/?$rid={}".format( fake_request_id_1 @@ -369,14 +480,12 @@ class TestSendRegister(object): mock_mqtt_transport.wait_for_publish_to_be_called() assert mock_mqtt_transport.publish.call_count == 1 assert mock_mqtt_transport.publish.call_args[1]["topic"] == fake_publish_topic - assert mock_mqtt_transport.publish.call_args[1]["payload"] == fake_msg_1 + assert mock_mqtt_transport.publish.call_args[1]["payload"] == fake_register_publish_payload # while we're waiting for that send to complete, send another event callback_2 = mocker.MagicMock() # provisioning_pipeline.send_message(fake_msg_2, callback_2) - mock_provisioning_pipeline.send_request( - request_id=fake_request_id_2, request_payload=fake_msg_2, callback=callback_2 - ) + mock_provisioning_pipeline.register(payload=fake_msg_2, callback=callback_2) # verify that we've called publish twice and verify that neither send_message # has completed (because we didn't do anything here to complete it). @@ -393,9 +502,7 @@ class TestSendRegister(object): mock_provisioning_pipeline.wait_for_on_connected_to_be_called() # send an event - mock_provisioning_pipeline.send_request( - request_id=fake_request_id, request_payload=fake_mqtt_payload - ) + mock_provisioning_pipeline.register(payload=fake_mqtt_payload) mock_mqtt_transport.on_mqtt_published(0) # disconnect @@ -403,40 +510,6 @@ class TestSendRegister(object): mock_mqtt_transport.disconnect.assert_called_once_with() -@pytest.mark.parametrize("params_security_clients", different_security_clients) -@pytest.mark.describe("Provisioning pipeline - Send Query") -class TestSendQuery(object): - @pytest.mark.it("Request calls publish on provider") - def test_send_query_calls_publish_on_provider( - self, mock_provisioning_pipeline, params_security_clients, mock_mqtt_transport - ): - mock_provisioning_pipeline.connect() - mock_mqtt_transport.on_mqtt_connected_handler() - mock_provisioning_pipeline.wait_for_on_connected_to_be_called() - mock_provisioning_pipeline.send_request( - request_id=fake_request_id, - request_payload=fake_mqtt_payload, - operation_id=fake_operation_id, - ) - - assert mock_mqtt_transport.connect.call_count == 1 - - if params_security_clients["client_class"].__name__ == "SymmetricKeySecurityClient": - assert mock_mqtt_transport.connect.call_args[1]["password"] is not None - assert_for_symmetric_key(mock_mqtt_transport.connect.call_args[1]["password"]) - elif params_security_clients["client_class"].__name__ == "X509SecurityClient": - assert mock_mqtt_transport.connect.call_args[1]["password"] is None - - fake_publish_topic = "$dps/registrations/GET/iotdps-get-operationstatus/?$rid={}&operationId={}".format( - fake_request_id, fake_operation_id - ) - - mock_mqtt_transport.wait_for_publish_to_be_called() - assert mock_mqtt_transport.publish.call_count == 1 - assert mock_mqtt_transport.publish.call_args[1]["topic"] == fake_publish_topic - assert mock_mqtt_transport.publish.call_args[1]["payload"] == fake_mqtt_payload - - @pytest.mark.parametrize("params_security_clients", different_security_clients) @pytest.mark.describe("Provisioning pipeline - Disconnect") class TestDisconnect(object): @@ -488,19 +561,3 @@ class TestEnable(object): assert mock_mqtt_transport.subscribe.call_count == 1 assert mock_mqtt_transport.subscribe.call_args[1]["topic"] == fake_sub_unsub_topic - - -@pytest.mark.parametrize("params_security_clients", different_security_clients) -@pytest.mark.describe("Provisioning pipeline - Disable") -class TestDisable(object): - @pytest.mark.it("Calls unsubscribe on provider") - def test_unsubscribe_calls_unsubscribe_on_provider( - self, mock_provisioning_pipeline, mock_mqtt_transport - ): - mock_provisioning_pipeline.connect() - mock_mqtt_transport.on_mqtt_connected_handler() - mock_provisioning_pipeline.wait_for_on_connected_to_be_called() - mock_provisioning_pipeline.disable_responses(None) - - assert mock_mqtt_transport.unsubscribe.call_count == 1 - assert mock_mqtt_transport.unsubscribe.call_args[1]["topic"] == fake_sub_unsub_topic diff --git a/azure-iot-device/tests/provisioning/test_provisioning_device_client.py b/azure-iot-device/tests/provisioning/test_provisioning_device_client.py index 1de3860a8..d1968f59a 100644 --- a/azure-iot-device/tests/provisioning/test_provisioning_device_client.py +++ b/azure-iot-device/tests/provisioning/test_provisioning_device_client.py @@ -13,6 +13,14 @@ from azure.iot.device.provisioning.abstract_provisioning_device_client import ( logging.basicConfig(level=logging.DEBUG) +class Wizard(object): + def __init__(self, first_name, last_name, dict_of_stuff): + self.first_name = first_name + self.last_name = last_name + self.props = dict_of_stuff + + +@pytest.mark.it("Init of abstract client raises exception") def test_raises_exception_on_init_of_abstract_client(mocker): fake_pipeline = mocker.MagicMock() with pytest.raises(TypeError): diff --git a/azure-iot-device/tests/provisioning/test_sync_provisioning_device_client.py b/azure-iot-device/tests/provisioning/test_sync_provisioning_device_client.py index ffccfdff0..6776fd009 100644 --- a/azure-iot-device/tests/provisioning/test_sync_provisioning_device_client.py +++ b/azure-iot-device/tests/provisioning/test_sync_provisioning_device_client.py @@ -11,7 +11,10 @@ from azure.iot.device.provisioning.models.registration_result import ( RegistrationResult, RegistrationState, ) -from azure.iot.device.provisioning.pipeline import pipeline_ops_provisioning +from azure.iot.device.provisioning.pipeline import exceptions as pipeline_exceptions +from azure.iot.device.provisioning import security, pipeline +import threading +from azure.iot.device import exceptions as client_exceptions logging.basicConfig(level=logging.DEBUG) @@ -29,143 +32,443 @@ fake_request_id = "request_1234" fake_device_id = "MyNimbus2000" fake_assigned_hub = "Dumbledore'sArmy" -fake_registration_state = RegistrationState(fake_device_id, fake_assigned_hub, fake_sub_status) + +@pytest.fixture +def registration_result(): + registration_state = RegistrationState(fake_device_id, fake_assigned_hub, fake_sub_status) + return RegistrationResult(fake_operation_id, fake_status, registration_state) -def create_success_result(): - return RegistrationResult( - fake_request_id, fake_operation_id, fake_status, fake_registration_state - ) - - -def create_error(): - return RuntimeError("Incoming Failure") - - -def fake_x509(): +@pytest.fixture +def x509(): return X509(fake_x509_cert_file_value, fake_x509_cert_key_file, fake_pass_phrase) -# automatically mock the transport for all tests in this file. @pytest.fixture(autouse=True) -def mock_transport(mocker): - mocker.patch( - "azure.iot.device.common.pipeline.pipeline_stages_mqtt.MQTTTransport", autospec=True +def provisioning_pipeline(mocker): + return mocker.MagicMock(wraps=FakeProvisioningPipeline()) + + +class FakeProvisioningPipeline: + def __init__(self): + self.responses_enabled = {} + + def connect(self, callback): + callback() + + def disconnect(self, callback): + callback() + + def enable_responses(self, callback): + callback() + + def register(self, payload, callback): + callback(result={}) + + +# automatically mock the pipeline for all tests in this file +@pytest.fixture(autouse=True) +def mock_pipeline_init(mocker): + return mocker.patch("azure.iot.device.provisioning.pipeline.ProvisioningPipeline") + + +class SharedClientCreateMethodUserOptionTests(object): + @pytest.mark.it( + "Sets the 'websockets' user option parameter on the PipelineConfig, if provided" ) + def test_websockets_option( + self, mocker, client_create_method, create_method_args, mock_pipeline_init + ): + client_create_method(*create_method_args, websockets=True) + + # Get configuration object + assert mock_pipeline_init.call_count == 1 + config = mock_pipeline_init.call_args[0][1] + + assert config.websockets + + # TODO: Show that input in the wrong format is formatted to the correct one. This test exists + # in the ProvisioningPipelineConfig object already, but we do not currently show that this is felt + # from the API level. + @pytest.mark.it("Sets the 'cipher' user option parameter on the PipelineConfig, if provided") + def test_cipher_option( + self, mocker, client_create_method, create_method_args, mock_pipeline_init + ): + + cipher = "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256" + client_create_method(*create_method_args, cipher=cipher) + + # Get configuration object + assert mock_pipeline_init.call_count == 1 + config = mock_pipeline_init.call_args[0][1] + + assert config.cipher == cipher + + @pytest.mark.it("Raises a TypeError if an invalid user option parameter is provided") + def test_invalid_option( + self, mocker, client_create_method, create_method_args, mock_pipeline_init + ): + with pytest.raises(TypeError): + client_create_method(*create_method_args, invalid_option="some_value") + + @pytest.mark.it("Sets default user options if none are provided") + def test_default_options( + self, mocker, client_create_method, create_method_args, mock_pipeline_init + ): + client_create_method(*create_method_args) + + # Get configuration object + assert mock_pipeline_init.call_count == 1 + config = mock_pipeline_init.call_args[0][1] + + assert not config.websockets + assert not config.cipher -@pytest.mark.describe("ProvisioningDeviceClient - Init") -class TestClientCreate(object): - xfail_notimplemented = pytest.mark.xfail(raises=NotImplementedError, reason="Unimplemented") - - @pytest.mark.it("Is created from a symmetric key and protocol") - @pytest.mark.parametrize( - "protocol", - [ - pytest.param("mqtt", id="mqtt"), - pytest.param(None, id="optional protocol"), - pytest.param("amqp", id="amqp", marks=xfail_notimplemented), - pytest.param("http", id="http", marks=xfail_notimplemented), - ], +@pytest.mark.describe("ProvisioningDeviceClient - Instantiation") +class TestClientInstantiation(object): + @pytest.mark.it( + "Stores the ProvisioningPipeline from the 'provisioning_pipeline' parameter in the '_provisioning_pipeline' attribute" ) - def test_create_from_symmetric_key(self, mocker, protocol): + def test_sets_provisioning_pipeline(self, provisioning_pipeline): + client = ProvisioningDeviceClient(provisioning_pipeline) + + assert client._provisioning_pipeline is provisioning_pipeline + + @pytest.mark.it( + "Instantiates with the initial value of the '_provisioning_payload' attribute set to None" + ) + def test_payload(self, provisioning_pipeline): + client = ProvisioningDeviceClient(provisioning_pipeline) + + assert client._provisioning_payload is None + + +@pytest.mark.describe("ProvisioningDeviceClient - .create_from_symmetric_key()") +class TestClientCreateFromSymmetricKey(SharedClientCreateMethodUserOptionTests): + @pytest.fixture + def client_create_method(self): + return ProvisioningDeviceClient.create_from_symmetric_key + + @pytest.fixture + def create_method_args(self): + return [fake_provisioning_host, fake_registration_id, fake_id_scope, fake_symmetric_key] + + @pytest.mark.it("Creates a SymmetricKeySecurityClient using the given parameters") + def test_security_client(self, mocker): + spy_sec_client = mocker.spy(security, "SymmetricKeySecurityClient") + + ProvisioningDeviceClient.create_from_symmetric_key( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + symmetric_key=fake_symmetric_key, + ) + + assert spy_sec_client.call_count == 1 + assert spy_sec_client.call_args == mocker.call( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + symmetric_key=fake_symmetric_key, + ) + + @pytest.mark.it( + "Uses the SymmetricKeySecurityClient object and the ProvisioningPipelineConfig object to create a ProvisioningPipeline" + ) + def test_pipeline(self, mocker, mock_pipeline_init): + # Note that the details of how the pipeline config is set up are covered in the + # SharedClientCreateMethodUserOptionTests + mock_pipeline_config = mocker.patch.object( + pipeline, "ProvisioningPipelineConfig" + ).return_value + mock_sec_client = mocker.patch.object(security, "SymmetricKeySecurityClient").return_value + + ProvisioningDeviceClient.create_from_symmetric_key( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + symmetric_key=fake_symmetric_key, + ) + + assert mock_pipeline_init.call_count == 1 + assert mock_pipeline_init.call_args == mocker.call(mock_sec_client, mock_pipeline_config) + + @pytest.mark.it("Uses the ProvisioningPipeline to instantiate the client") + def test_client_creation(self, mocker, mock_pipeline_init): + spy_client_init = mocker.spy(ProvisioningDeviceClient, "__init__") + + ProvisioningDeviceClient.create_from_symmetric_key( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + symmetric_key=fake_symmetric_key, + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(mocker.ANY, mock_pipeline_init.return_value) + + @pytest.mark.it("Returns the instantiated client") + def test_returns_client(self, mocker): client = ProvisioningDeviceClient.create_from_symmetric_key( - fake_provisioning_host, fake_symmetric_key, fake_registration_id, fake_id_scope + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + symmetric_key=fake_symmetric_key, ) assert isinstance(client, ProvisioningDeviceClient) - assert client._provisioning_pipeline is not None - @pytest.mark.it("Is created from a x509 certificate key and protocol") + +@pytest.mark.describe("ProvisioningDeviceClient - .create_from_x509_certificate()") +class TestClientCreateFromX509Certificate(SharedClientCreateMethodUserOptionTests): + @pytest.fixture + def client_create_method(self): + return ProvisioningDeviceClient.create_from_x509_certificate + + @pytest.fixture + def create_method_args(self, x509): + return [fake_provisioning_host, fake_registration_id, fake_id_scope, x509] + + @pytest.mark.it("Creates an X509SecurityClient using the given parameters") + def test_security_client(self, mocker, x509): + spy_sec_client = mocker.spy(security, "X509SecurityClient") + + ProvisioningDeviceClient.create_from_x509_certificate( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + x509=x509, + ) + + assert spy_sec_client.call_count == 1 + assert spy_sec_client.call_args == mocker.call( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + x509=x509, + ) + + @pytest.mark.it( + "Uses the X509SecurityClient object and the ProvisioningPipelineConfig object to create a ProvisioningPipeline" + ) + def test_pipeline(self, mocker, mock_pipeline_init, x509): + # Note that the details of how the pipeline config is set up are covered in the + # SharedClientCreateMethodUserOptionTests + mock_pipeline_config = mocker.patch.object( + pipeline, "ProvisioningPipelineConfig" + ).return_value + mock_sec_client = mocker.patch.object(security, "X509SecurityClient").return_value + + ProvisioningDeviceClient.create_from_x509_certificate( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + x509=x509, + ) + + assert mock_pipeline_init.call_count == 1 + assert mock_pipeline_init.call_args == mocker.call(mock_sec_client, mock_pipeline_config) + + @pytest.mark.it("Uses the ProvisioningPipeline to instantiate the client") + def test_client_creation(self, mocker, mock_pipeline_init, x509): + spy_client_init = mocker.spy(ProvisioningDeviceClient, "__init__") + + ProvisioningDeviceClient.create_from_x509_certificate( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + x509=x509, + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(mocker.ANY, mock_pipeline_init.return_value) + + @pytest.mark.it("Returns the instantiated client") + def test_returns_client(self, mocker, x509): + client = ProvisioningDeviceClient.create_from_x509_certificate( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + x509=x509, + ) + assert isinstance(client, ProvisioningDeviceClient) + + +@pytest.mark.describe("ProvisioningDeviceClient - .register()") +class TestClientRegister(object): + @pytest.mark.it("Implicitly enables responses from provisioning service if not already enabled") + def test_enables_provisioning_only_if_not_already_enabled( + self, mocker, provisioning_pipeline, registration_result + ): + # Override callback to pass successful result + def register_complete_success_callback(payload, callback): + callback(result=registration_result) + + mocker.patch.object( + provisioning_pipeline, "register", side_effect=register_complete_success_callback + ) + + provisioning_pipeline.responses_enabled.__getitem__.return_value = False + + # assert provisioning_pipeline.responses_enabled is False + client = ProvisioningDeviceClient(provisioning_pipeline) + client.register() + + assert provisioning_pipeline.enable_responses.call_count == 1 + + provisioning_pipeline.enable_responses.reset_mock() + + provisioning_pipeline.responses_enabled.__getitem__.return_value = True + client.register() + assert provisioning_pipeline.enable_responses.call_count == 0 + + @pytest.mark.it("Begins a 'register' pipeline operation") + def test_register_calls_pipeline_register( + self, provisioning_pipeline, mocker, registration_result + ): + def register_complete_success_callback(payload, callback): + callback(result=registration_result) + + mocker.patch.object( + provisioning_pipeline, "register", side_effect=register_complete_success_callback + ) + client = ProvisioningDeviceClient(provisioning_pipeline) + client.register() + assert provisioning_pipeline.register.call_count == 1 + + @pytest.mark.it( + "Waits for the completion of the 'register' pipeline operation before returning" + ) + def test_waits_for_pipeline_op_completion(self, mocker, registration_result): + manual_provisioning_pipeline_with_callback = mocker.MagicMock() + event_init_mock = mocker.patch.object(threading, "Event") + event_mock = event_init_mock.return_value + pipeline_function = manual_provisioning_pipeline_with_callback.register + + def check_callback_completes_event(): + # Assert exactly one Event was instantiated so we know the following asserts + # are related to the code under test ONLY + assert event_init_mock.call_count == 1 + + # Assert waiting for Event to complete + assert event_mock.wait.call_count == 1 + assert event_mock.set.call_count == 0 + + # Manually trigger callback + cb = pipeline_function.call_args[1]["callback"] + cb(result=registration_result) + + # Assert Event is now completed + assert event_mock.set.call_count == 1 + + event_mock.wait.side_effect = check_callback_completes_event + + client = ProvisioningDeviceClient(manual_provisioning_pipeline_with_callback) + client._provisioning_payload = "payload" + client.register() + + @pytest.mark.it("Returns the registration result that the pipeline returned") + def test_verifies_registration_result_returned( + self, mocker, provisioning_pipeline, registration_result + ): + result = registration_result + + def register_complete_success_callback(payload, callback): + callback(result=result) + + mocker.patch.object( + provisioning_pipeline, "register", side_effect=register_complete_success_callback + ) + + client = ProvisioningDeviceClient(provisioning_pipeline) + result_returned = client.register() + assert result_returned == result + + @pytest.mark.it( + "Raises a client error if the `register` pipeline operation calls back with a pipeline error" + ) @pytest.mark.parametrize( - "protocol", + "pipeline_error,client_error", [ - pytest.param("mqtt", id="mqtt"), - pytest.param(None, id="optional protocol"), - pytest.param("amqp", id="amqp", marks=xfail_notimplemented), - pytest.param("http", id="http", marks=xfail_notimplemented), + pytest.param( + pipeline_exceptions.ConnectionDroppedError, + client_exceptions.ConnectionDroppedError, + id="ConnectionDroppedError->ConnectionDroppedError", + ), + pytest.param( + pipeline_exceptions.ConnectionFailedError, + client_exceptions.ConnectionFailedError, + id="ConnectionFailedError->ConnectionFailedError", + ), + pytest.param( + pipeline_exceptions.UnauthorizedError, + client_exceptions.CredentialError, + id="UnauthorizedError->CredentialError", + ), + pytest.param( + pipeline_exceptions.ProtocolClientError, + client_exceptions.ClientError, + id="ProtocolClientError->ClientError", + ), + pytest.param(Exception, client_exceptions.ClientError, id="Exception->ClientError"), ], ) - def test_create_from_x509_cert(self, mocker, protocol): - client = ProvisioningDeviceClient.create_from_x509_certificate( - fake_provisioning_host, fake_registration_id, fake_id_scope, fake_x509() - ) - assert isinstance(client, ProvisioningDeviceClient) - assert client._provisioning_pipeline is not None - - -@pytest.mark.describe("ProvisioningDeviceClient") -class TestClientRegister(object): - @pytest.mark.it( - "Register calls register on polling machine with passed in callback and returns the registration result" - ) - def test_client_register_success_calls_polling_machine_register_with_callback( - self, mocker, mock_polling_machine + def test_raises_error_on_pipeline_op_error( + self, mocker, pipeline_error, client_error, provisioning_pipeline ): - # Override callback to pass successful result - def register_complete_success_callback(callback): - callback(result=create_success_result()) + error = pipeline_error() + + def register_complete_failure_callback(payload, callback): + callback(result=None, error=error) mocker.patch.object( - mock_polling_machine, "register", side_effect=register_complete_success_callback + provisioning_pipeline, "register", side_effect=register_complete_failure_callback ) - mqtt_provisioning_pipeline = mocker.MagicMock() - mock_polling_machine_init = mocker.patch( - "azure.iot.device.provisioning.provisioning_device_client.PollingMachine" - ) - mock_polling_machine_init.return_value = mock_polling_machine - - client = ProvisioningDeviceClient(mqtt_provisioning_pipeline) - result = client.register() - - assert mock_polling_machine.register.call_count == 1 - assert callable(mock_polling_machine.register.call_args[1]["callback"]) - assert result is not None - assert result.registration_state == fake_registration_state - assert result.status == fake_status - assert result.registration_state == fake_registration_state - assert result.registration_state.device_id == fake_device_id - assert result.registration_state.assigned_hub == fake_assigned_hub - - @pytest.mark.it( - "Register calls register on polling machine with passed in callback and raises the error when an error has occured" - ) - def test_client_register_failure_calls_polling_machine_register_with_callback( - self, mocker, mock_polling_machine - ): - # Override callback to pass successful result - def register_complete_failure_callback(callback): - callback(result=None, error=create_error()) - - mocker.patch.object( - mock_polling_machine, "register", side_effect=register_complete_failure_callback - ) - - mqtt_provisioning_pipeline = mocker.MagicMock() - mock_polling_machine_init = mocker.patch( - "azure.iot.device.provisioning.provisioning_device_client.PollingMachine" - ) - mock_polling_machine_init.return_value = mock_polling_machine - - client = ProvisioningDeviceClient(mqtt_provisioning_pipeline) - with pytest.raises(RuntimeError): + client = ProvisioningDeviceClient(provisioning_pipeline) + with pytest.raises(client_error) as e_info: client.register() - assert mock_polling_machine.register.call_count == 1 - assert callable(mock_polling_machine.register.call_args[1]["callback"]) + assert e_info.value.__cause__ is error + assert provisioning_pipeline.register.call_count == 1 - @pytest.mark.it("Cancel calls cancel on polling machine with passed in callback") - def test_client_cancel_calls_polling_machine_cancel_with_callback( - self, mocker, mock_polling_machine - ): - mqtt_provisioning_pipeline = mocker.MagicMock() - mock_polling_machine_init = mocker.patch( - "azure.iot.device.provisioning.provisioning_device_client.PollingMachine" - ) - mock_polling_machine_init.return_value = mock_polling_machine - client = ProvisioningDeviceClient(mqtt_provisioning_pipeline) - client.cancel() +@pytest.mark.describe("ProvisioningDeviceClient - .set_provisioning_payload()") +class TestClientProvisioningPayload(object): + @pytest.mark.it("Sets the payload on the provisioning payload attribute") + @pytest.mark.parametrize( + "payload_input", + [ + pytest.param("Hello Hogwarts", id="String input"), + pytest.param(222, id="Integer input"), + pytest.param(object(), id="Object input"), + pytest.param(None, id="None input"), + pytest.param([1, "str"], id="List input"), + pytest.param({"a": 2}, id="Dictionary input"), + ], + ) + def test_set_payload(self, mocker, payload_input): + provisioning_pipeline = mocker.MagicMock() - assert mock_polling_machine.cancel.call_count == 1 - assert callable(mock_polling_machine.cancel.call_args[1]["callback"]) + client = ProvisioningDeviceClient(provisioning_pipeline) + client.provisioning_payload = payload_input + assert client._provisioning_payload == payload_input + + @pytest.mark.it("Gets the payload from the provisioning payload property") + @pytest.mark.parametrize( + "payload_input", + [ + pytest.param("Hello Hogwarts", id="String input"), + pytest.param(222, id="Integer input"), + pytest.param(object(), id="Object input"), + pytest.param(None, id="None input"), + pytest.param([1, "str"], id="List input"), + pytest.param({"a": 2}, id="Dictionary input"), + ], + ) + def test_get_payload(self, mocker, payload_input): + provisioning_pipeline = mocker.MagicMock() + + client = ProvisioningDeviceClient(provisioning_pipeline) + client.provisioning_payload = payload_input + assert client.provisioning_payload == payload_input diff --git a/azure-iot-device/tests/test_product_info.py b/azure-iot-device/tests/test_product_info.py new file mode 100644 index 000000000..e31bfd02a --- /dev/null +++ b/azure-iot-device/tests/test_product_info.py @@ -0,0 +1,69 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest +from azure.iot.device.product_info import ProductInfo +import platform +from azure.iot.device.constant import VERSION, IOTHUB_IDENTIFIER, PROVISIONING_IDENTIFIER + + +check_agent_format = ( + "{identifier}/{version}({python_runtime};{os_type} {os_release};{architecture})" +) + + +@pytest.mark.describe("ProductInfo") +class TestProductInfo(object): + @pytest.mark.it( + "Contains python version, operating system and architecture of the system in the iothub agent string" + ) + def test_get_iothub_user_agent(self): + user_agent = ProductInfo.get_iothub_user_agent() + + assert IOTHUB_IDENTIFIER in user_agent + assert VERSION in user_agent + assert platform.python_version() in user_agent + assert platform.system() in user_agent + assert platform.version() in user_agent + assert platform.machine() in user_agent + + @pytest.mark.it("Checks if the format of the agent string is as expected") + def test_checks_format_iothub_agent(self): + expected_part_agent = check_agent_format.format( + identifier=IOTHUB_IDENTIFIER, + version=VERSION, + python_runtime=platform.python_version(), + os_type=platform.system(), + os_release=platform.version(), + architecture=platform.machine(), + ) + user_agent = ProductInfo.get_iothub_user_agent() + assert expected_part_agent in user_agent + + @pytest.mark.it( + "Contains python version, operating system and architecture of the system in the provisioning agent string" + ) + def test_get_provisioning_user_agent(self): + user_agent = ProductInfo.get_provisioning_user_agent() + + assert PROVISIONING_IDENTIFIER in user_agent + assert VERSION in user_agent + assert platform.python_version() in user_agent + assert platform.system() in user_agent + assert platform.version() in user_agent + assert platform.machine() in user_agent + + @pytest.mark.it("Checks if the format of the agent string is as expected") + def test_checks_format_provisioning_agent(self): + expected_part_agent = check_agent_format.format( + identifier=PROVISIONING_IDENTIFIER, + version=VERSION, + python_runtime=platform.python_version(), + os_type=platform.system(), + os_release=platform.version(), + architecture=platform.machine(), + ) + user_agent = ProductInfo.get_provisioning_user_agent() + assert expected_part_agent in user_agent diff --git a/azure-iot-hub/.bumpverion.cfg b/azure-iot-hub/.bumpverion.cfg new file mode 100644 index 000000000..38d73fe88 --- /dev/null +++ b/azure-iot-hub/.bumpverion.cfg @@ -0,0 +1,7 @@ +[bumpversion] +current_version = 2.1.0 +parse = (?P\d+)\.(?P\d+)\.(?P\d+) +serialize = {major}.{minor}.{patch} + +[bumpversion:part] + diff --git a/azure-iot-hub/README.md b/azure-iot-hub/README.md index e5e2df0c2..5c0a09550 100644 --- a/azure-iot-hub/README.md +++ b/azure-iot-hub/README.md @@ -2,8 +2,6 @@ The Azure IoTHub Service SDK for Python provides functionality for communicating with the Azure IoT Hub. -**Note that this SDK is currently in preview, and is subject to change.** - ## Features The SDK provides the following clients: @@ -13,124 +11,13 @@ The SDK provides the following clients: * Provides CRUD operations for device on IoTHub * Get statistics about the IoTHub service and devices -* ### Digital Twin Service Client - - * Read and update Digital Twin - * Read Digital Twin Interface Instances - * Read Model - -These clients are available with an asynchronous API, as well as a blocking synchronous API for compatibility scenarios. **We recommend you use Python 3.7+ and the asynchronous API.** - -| Python Version | Synchronous API | -| -------------- | --------------- | -| Python 3.5.3+ | **YES** | -| Python 3.4 | **YES** | -| Python 2.7 | **YES** | - ## Installation ```python pip install azure-iot-hub ``` -## Set up an IoT Hub - -1. Install the [Azure CLI](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli?view=azure-cli-latest) (or use the [Azure Cloud Shell](https://shell.azure.com/)) and use it to [create an Azure IoT Hub](https://docs.microsoft.com/en-us/cli/azure/iot/hub?view=azure-cli-latest#az-iot-hub-create). - -```bash -az iot hub create --resource-group --name -``` - -* Note that this operation make take a few minutes. - -## How to use the IoTHub Registry Manager - -* ### Create an IoTHubRegistryManager - -```python -registry_manager = IoTHubRegistryManager(iothub_connection_str) -``` - -* ### Create a device - -```python -new_device = registry_manager.create_device_with_sas(device_id, primary_key, secondary_key, device_state) -``` - -* ### Read device information - -```python -device = registry_manager.get_device(device_id) -``` - -* ### Update device information - -```python -device_updated = registry_manager.update_device_with_sas( - device_id, etag, primary_key, secondary_key, device_state) -``` - -* ### Delete device - -```python -registry_manager.delete_device(device_id) -``` - -* ### Get service statistics - -```python -registry_statistics = registry_manager.get_service_statistics() -``` - -* ### Get device registry statistics - -```python -registry_statistics = registry_manager.get_device_registry_statistics() -``` - -## How to use the Digital Twin Service Client - -* ### Create an DigitalTwinServiceClient - -```python -digital_twin_service_client = DigitalTwinServiceClient(iothub_connection_str) -``` - -* ### Get DigitalTwin of a particular device - -```python -digital_twin = digital_twin_service_client.get_digital_twin(device_id) -``` - -* ### Get a DigitalTwin Interface Instance - -```python -digital_twin_interface_instance = digital_twin_service_client.get_digital_twin_interface_instance( - device_id, interface_instance_name -) -``` - -* ### Update DigitalTwin with a patch - -```python -digital_twin_updated = digital_twin_service_client.update_digital_twin(device_id, patch, etag) -``` - -* ### Update a DigitalTwin property by name - -```python -digital_twin_service_client.update_digital_twin_property( - device_id, interface_instance_name, property_name, property_value -) -``` - -* ### Get a Model - -```python -digital_twin_model = digital_twin_service_client.get_model(model_id) -``` - -## Additional Samples +## IoTHub Samples Check out the [samples repository](https://github.com/Azure/azure-iot-sdk-python/tree/master/azure-iot-hub/samples) for more detailed samples @@ -141,7 +28,4 @@ Our SDK makes use of docstrings which means you can find API documentation direc ```python >>> from azure.iot.hub import IoTHubRegistryManager >>> help(IoTHubRegistryManager) - ->>> from azure.iot.hub import DigitalTwinServiceClient ->>> help(DigitalTwinServiceClient) ``` diff --git a/azure-iot-hub/azure/iot/hub/__init__.py b/azure-iot-hub/azure/iot/hub/__init__.py index d92b872c7..94d607484 100644 --- a/azure-iot-hub/azure/iot/hub/__init__.py +++ b/azure-iot-hub/azure/iot/hub/__init__.py @@ -3,8 +3,14 @@ This library provides service clients and associated models for communicating with Azure IoTHub Services. """ -from .digital_twin_service_client import DigitalTwinServiceClient -from .digital_twin_service_client import DigitalTwin from .iothub_registry_manager import IoTHubRegistryManager +from .iothub_configuration_manager import IoTHubConfigurationManager +from .iothub_job_manager import IoTHubJobManager +from .iothub_http_runtime_manager import IoTHubHttpRuntimeManager -__all__ = ["DigitalTwinServiceClient", "IoTHubRegistryManager", "DigitalTwin"] +__all__ = [ + "IoTHubRegistryManager", + "IoTHubConfigurationManager", + "IoTHubJobManager", + "IoTHubHttpRuntimeManager", +] diff --git a/azure-iot-hub/azure/iot/hub/constant.py b/azure-iot-hub/azure/iot/hub/constant.py index a7830ffe4..303b3e633 100644 --- a/azure-iot-hub/azure/iot/hub/constant.py +++ b/azure-iot-hub/azure/iot/hub/constant.py @@ -6,4 +6,4 @@ """This module defines constants for use across the azure-iot-hub package """ -VERSION = "2.0.0-preview.10" +VERSION = "2.1.1" diff --git a/azure-iot-hub/azure/iot/hub/iothub_amqp_client.py b/azure-iot-hub/azure/iot/hub/iothub_amqp_client.py new file mode 100644 index 000000000..d43a5f4a5 --- /dev/null +++ b/azure-iot-hub/azure/iot/hub/iothub_amqp_client.py @@ -0,0 +1,83 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import logging +import os +import sys +import base64 +import time +import hashlib +import hmac +from uuid import uuid4 +import six.moves.urllib as urllib + +try: + from urllib import quote, quote_plus, urlencode # Py2 +except Exception: + from urllib.parse import quote, quote_plus, urlencode + +import uamqp + +default_sas_expiry = 30 + + +class IoTHubAmqpClient: + def _generate_auth_token(self, uri, sas_name, sas_value): + sas = base64.b64decode(sas_value) + expiry = str(int(time.time() + default_sas_expiry)) + string_to_sign = (uri + "\n" + expiry).encode("utf-8") + signed_hmac_sha256 = hmac.HMAC(sas, string_to_sign, hashlib.sha256) + signature = urllib.parse.quote(base64.b64encode(signed_hmac_sha256.digest())) + return "SharedAccessSignature sr={}&sig={}&se={}&skn={}".format( + uri, signature, expiry, sas_name + ) + + def _build_amqp_endpoint(self, hostname, shared_access_key_name, shared_access_key): + hub_name = hostname.split(".")[0] + endpoint = "{}@sas.root.{}".format(shared_access_key_name, hub_name) + endpoint = quote_plus(endpoint) + sas_token = self._generate_auth_token( + hostname, shared_access_key_name, shared_access_key + "=" + ) + endpoint = endpoint + ":{}@{}".format(quote_plus(sas_token), hostname) + return endpoint + + def __init__(self, hostname, shared_access_key_name, shared_access_key): + self.endpoint = self._build_amqp_endpoint( + hostname, shared_access_key_name, shared_access_key + ) + operation = "/messages/devicebound" + target = "amqps://" + self.endpoint + operation + self.amqp_client = uamqp.SendClient(target) + + def disconnect_sync(self): + """ + Disconnect the Amqp client. + """ + if self.amqp_client: + self.amqp_client.close() + self.amqp_client = None + + def send_message_to_device(self, device_id, message): + """Send a message to the specified deivce. + + :param str device_id: The name (Id) of the device. + :param str message: The message that is to be delivered to the device. + + :raises: Exception if the Send command is not able to send the message + """ + msg_content = message + app_properties = {} + msg_props = uamqp.message.MessageProperties() + msg_props.to = "/devices/{}/messages/devicebound".format(device_id) + msg_props.message_id = str(uuid4()) + message = uamqp.Message( + msg_content, properties=msg_props, application_properties=app_properties + ) + self.amqp_client.queue_message(message) + results = self.amqp_client.send_all_messages(close_on_done=False) + if uamqp.constants.MessageState.SendFailed in results: + raise Exception("C2D message sned failure") diff --git a/azure-iot-hub/azure/iot/hub/iothub_configuration_manager.py b/azure-iot-hub/azure/iot/hub/iothub_configuration_manager.py new file mode 100644 index 000000000..28c2c1548 --- /dev/null +++ b/azure-iot-hub/azure/iot/hub/iothub_configuration_manager.py @@ -0,0 +1,127 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from .auth import ConnectionStringAuthentication +from .protocol.iot_hub_gateway_service_ap_is import IotHubGatewayServiceAPIs as protocol_client +from .protocol.models import Configuration, ConfigurationContent, ConfigurationQueriesTestInput + + +class IoTHubConfigurationManager(object): + """A class to provide convenience APIs for IoTHub Registry Manager operations, + based on top of the auto generated IotHub REST APIs + """ + + def __init__(self, connection_string): + """Initializer for a Configuration Manager Service client. + + After a successful creation the class has been authenticated with IoTHub and + it is ready to call the member APIs to communicate with IoTHub. + + :param str connection_string: The IoTHub connection string used to authenticate connection + with IoTHub. + + :returns: Instance of the IoTHubRegistryManager object. + :rtype: :class:`azure.iot.hub.IoTHubRegistryManager` + """ + + self.auth = ConnectionStringAuthentication(connection_string) + self.protocol = protocol_client(self.auth, "https://" + self.auth["HostName"]) + + def get_configuration(self, configuration_id): + """Retrieves the IoTHub configuration for a particular device. + + :param str configuration_id: The id of the configuration. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: The Configuration object. + """ + return self.protocol.configuration.get(configuration_id) + + def create_configuration(self, configuration): + """Creates a configuration for devices or modules of an IoTHub. + + :param str configuration_id: The id of the configuration. + :param Configuration configuration: The configuration to create. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: Configuration object containing the created configuration. + """ + return self.protocol.configuration.create_or_update(configuration.id, configuration) + + def update_configuration(self, configuration, etag): + """Updates a configuration for devices or modules of an IoTHub. + Note: that configuration Id and Content cannot be updated by the user. + + :param str configuration_id: The id of the configuration. + :param Configuration configuration: The configuration contains the updated configuration. + :param str etag: The etag (if_match) value to use for the update operation. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: Configuration object containing the updated configuration. + """ + return self.protocol.configuration.create_or_update(configuration.id, configuration, etag) + + def delete_configuration(self, configuration_id, etag=None): + """Deletes a configuration from an IoTHub. + + :param str configuration_id: The id of the configuration. + :param Configuration configuration: The configuration to create. + :param str etag: The etag (if_match) value to use for the delete operation. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: Configuration object containing the updated configuration. + """ + if etag is None: + etag = "*" + + return self.protocol.configuration.delete(configuration_id, etag) + + def get_configurations(self, max_count=None): + """Retrieves multiple configurations for device and modules of an IoTHub. + Returns the specified number of configurations. Pagination is not supported. + + :param int max_count: The maximum number of configurations requested. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: The list[Configuration] object. + """ + return self.protocol.configuration.get_configurations(max_count) + + def test_configuration_queries(self, configuration_queries_test_input): + """Validates the target condition query and custom metric queries for a + configuration. + + :param ConfigurationQueriesTestInput configuration_queries_test_input: The queries test input. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: The ConfigurationQueriesTestResponse object. + """ + return self.protocol.configuration.test_queries(configuration_queries_test_input) + + def apply_configuration_on_edge_device(self, device_id, configuration_content): + """Applies the provided configuration content to the specified edge + device. Modules content is mandantory. + + :param ConfigurationContent configuration_content: The name (Id) of the edge device. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: An object. + """ + return self.protocol.configuration.apply_on_edge_device(device_id, configuration_content) diff --git a/azure-iot-hub/azure/iot/hub/iothub_http_runtime_manager.py b/azure-iot-hub/azure/iot/hub/iothub_http_runtime_manager.py new file mode 100644 index 000000000..2a8ca520d --- /dev/null +++ b/azure-iot-hub/azure/iot/hub/iothub_http_runtime_manager.py @@ -0,0 +1,64 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from .auth import ConnectionStringAuthentication +from .protocol.iot_hub_gateway_service_ap_is import IotHubGatewayServiceAPIs as protocol_client + + +class IoTHubHttpRuntimeManager(object): + """A class to provide convenience APIs for IoTHub Http Runtime Manager operations, + based on top of the auto generated IotHub REST APIs + """ + + def __init__(self, connection_string): + """Initializer for a Http Runtime Manager Service client. + + After a successful creation the class has been authenticated with IoTHub and + it is ready to call the member APIs to communicate with IoTHub. + + :param str connection_string: The IoTHub connection string used to authenticate connection + with IoTHub. + + :returns: Instance of the IoTHubHttpRuntimeManager object. + :rtype: :class:`azure.iot.hub.IoTHubHttpRuntimeManager` + """ + + self.auth = ConnectionStringAuthentication(connection_string) + self.protocol = protocol_client(self.auth, "https://" + self.auth["HostName"]) + + def receive_feedback_notification(self): + """This method is used to retrieve feedback of a cloud-to-device message. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: None. + """ + return self.protocol.http_runtime.receive_feedback_notification() + + def complete_feedback_notification(self, lock_token): + """This method completes a feedback message. + + :param lock_token Lock token. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: None. + """ + return self.protocol.http_runtime.complete_feedback_notification(lock_token) + + def abandon_feedback_notification(self, lock_token): + """This method abandons a feedback message. + + :param lock_token Lock token. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: None. + """ + return self.protocol.http_runtime.abandon_feedback_notification(lock_token) diff --git a/azure-iot-hub/azure/iot/hub/iothub_job_manager.py b/azure-iot-hub/azure/iot/hub/iothub_job_manager.py new file mode 100644 index 000000000..83b970718 --- /dev/null +++ b/azure-iot-hub/azure/iot/hub/iothub_job_manager.py @@ -0,0 +1,127 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from .auth import ConnectionStringAuthentication +from .protocol.iot_hub_gateway_service_ap_is import IotHubGatewayServiceAPIs as protocol_client +from .protocol.models import Configuration, ConfigurationContent, ConfigurationQueriesTestInput + + +class IoTHubJobManager(object): + """A class to provide convenience APIs for IoTHub Job Manager operations, + based on top of the auto generated IotHub REST APIs + """ + + def __init__(self, connection_string): + """Initializer for a Job Manager Service client. + + After a successful creation the class has been authenticated with IoTHub and + it is ready to call the member APIs to communicate with IoTHub. + + :param str connection_string: The IoTHub connection string used to authenticate connection + with IoTHub. + + :returns: Instance of the IoTHubJobManager object. + :rtype: :class:`azure.iot.hub.IoTHubJobManager` + """ + + self.auth = ConnectionStringAuthentication(connection_string) + self.protocol = protocol_client(self.auth, "https://" + self.auth["HostName"]) + + def create_import_export_job(self, job_properties): + """Creates a new import/export job on an IoT hub. + + :param job_properties job_properties: Specifies the job specification. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: JobProperties object containing the created job. + """ + return self.protocol.job_client.create_import_export_job(job_properties) + + def get_import_export_jobs(self): + """Retrieves the status of all import/export jobs on an IoTHub. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: The list[JobProperties] object. + """ + return self.protocol.job_client.get_import_export_jobs() + + def get_import_export_job(self, job_id): + """Retrieves the status of an import/export job on an IoTHub. + + :param str job_id: The ID of the job. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: The JobProperties object containing the requested job. + """ + return self.protocol.job_client.get_import_export_job(job_id) + + def cancel_import_export_job(self, job_id): + """Cancels an import/export job on an IoT hub. + + :param str job_id: The ID of the job. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: Object. + """ + return self.protocol.job_client.cancel_import_export_job(job_id) + + def create_job(self, job_id, job_request): + """Creates a new job to schedule update twins or device direct methods on an IoT hub. + + :param str job_id: The ID of the job. + :param job_request job_request: Specifies the job. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: JobResponse object containing the created job. + """ + return self.protocol.job_client.create_job(job_id, job_request) + + def get_job(self, job_id): + """Retrieves the details of a scheduled job on an IoTHub. + + :param str job_id: The ID of the job. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: The JobResponse object containing the requested details. + """ + return self.protocol.job_client.get_job(job_id) + + def cancel_job(self, job_id): + """Cancels a scheduled job on an IoT hub. + + :param str job_id: The ID of the job. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: JobResponse object containing the cancelled job. + """ + return self.protocol.job_client.cancel_job(job_id) + + def query_jobs(self, job_type, job_status): + """Query an IoT hub to retrieve information regarding jobs using the IoT Hub query language. + + :param str job_type: The type of the jobs. + :param str job_status: The status of the jobs. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: QueryResult object containing the jobs. + """ + return self.protocol.job_client.query_jobs(job_type, job_status) diff --git a/azure-iot-hub/azure/iot/hub/iothub_registry_manager.py b/azure-iot-hub/azure/iot/hub/iothub_registry_manager.py index c481c310c..b5311252c 100644 --- a/azure-iot-hub/azure/iot/hub/iothub_registry_manager.py +++ b/azure-iot-hub/azure/iot/hub/iothub_registry_manager.py @@ -3,23 +3,49 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- - +from .iothub_amqp_client import IoTHubAmqpClient as iothub_amqp_client from .auth import ConnectionStringAuthentication -from .protocol.iot_hub_gateway_service_ap_is20190701_preview import ( - IotHubGatewayServiceAPIs20190701Preview as protocol_client, -) +from .protocol.iot_hub_gateway_service_ap_is import IotHubGatewayServiceAPIs as protocol_client from .protocol.models import ( Device, Module, SymmetricKey, X509Thumbprint, AuthenticationMechanism, - Configuration, ServiceStatistics, RegistryStatistics, + QuerySpecification, + Twin, + CloudToDeviceMethod, + CloudToDeviceMethodResult, ) +class QueryResult(object): + """The query result. + :param type: The query result type. Possible values include: 'unknown', + 'twin', 'deviceJob', 'jobResponse', 'raw', 'enrollment', + 'enrollmentGroup', 'deviceRegistration' + :type type: str or ~protocol.models.enum + :param items: The query result items, as a collection. + :type items: list[object] + :param continuation_token: Request continuation token. + :type continuation_token: str + """ + + _attribute_map = { + "type": {"key": "type", "type": "str"}, + "items": {"key": "items", "type": "[object]"}, + "continuation_token": {"key": "continuationToken", "type": "str"}, + } + + def __init__(self, **kwargs): + super(QueryResult, self).__init__(**kwargs) + self.type = kwargs.get("type", None) + self.items = kwargs.get("items", None) + self.continuation_token = kwargs.get("continuation_token", None) + + class IoTHubRegistryManager(object): """A class to provide convenience APIs for IoTHub Registry Manager operations, based on top of the auto generated IotHub REST APIs @@ -31,24 +57,35 @@ class IoTHubRegistryManager(object): After a successful creation the class has been authenticated with IoTHub and it is ready to call the member APIs to communicate with IoTHub. - :param: str connection_string: The authentication information - (IoTHub connection string) to connect to IoTHub. + :param str connection_string: The IoTHub connection string used to authenticate connection + with IoTHub. - :returns: IoTHubRegistryManager object. + :returns: Instance of the IoTHubRegistryManager object. + :rtype: :class:`azure.iot.hub.IoTHubRegistryManager` """ - self.auth = ConnectionStringAuthentication(connection_string) self.protocol = protocol_client(self.auth, "https://" + self.auth["HostName"]) + self.amqp_svc_client = iothub_amqp_client( + self.auth["HostName"], self.auth["SharedAccessKeyName"], self.auth["SharedAccessKey"] + ) + + def __del__(self): + """ + Deinitializer for a Registry Manager Service client. + """ + self.amqp_svc_client.disconnect_sync() def create_device_with_sas(self, device_id, primary_key, secondary_key, status): """Creates a device identity on IoTHub using SAS authentication. - :param str device_id: The name (deviceId) of the device. + :param str device_id: The name (Id) of the device. :param str primary_key: Primary authentication key. :param str secondary_key: Secondary authentication key. - :param str status: Initital state of the created device (enabled or disabled). + :param str status: Initital state of the created device. + (Possible values: "enabled" or "disabled") - :raises: HttpOperationError if the HTTP response status is not in [200]. + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. :returns: Device object containing the created device. """ @@ -61,17 +98,19 @@ class IoTHubRegistryManager(object): } device = Device(**kwargs) - return self.protocol.service.create_or_update_device(device_id, device) + return self.protocol.registry_manager.create_or_update_device(device_id, device) def create_device_with_x509(self, device_id, primary_thumbprint, secondary_thumbprint, status): """Creates a device identity on IoTHub using X509 authentication. - :param str device_id: The name (deviceId) of the device. + :param str device_id: The name (Id) of the device. :param str primary_thumbprint: Primary X509 thumbprint. :param str secondary_thumbprint: Secondary X509 thumbprint. - :param str status: Initital state of the created device (enabled or disabled). + :param str status: Initital state of the created device. + (Possible values: "enabled" or "disabled") - :raises: HttpOperationError if the HTTP response status is not in [200]. + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. :returns: Device object containing the created device. """ @@ -88,15 +127,17 @@ class IoTHubRegistryManager(object): } device = Device(**kwargs) - return self.protocol.service.create_or_update_device(device_id, device) + return self.protocol.registry_manager.create_or_update_device(device_id, device) def create_device_with_certificate_authority(self, device_id, status): """Creates a device identity on IoTHub using certificate authority. - :param str device_id: The name (deviceId) of the device. - :param str status: Initital state of the created device (enabled or disabled). + :param str device_id: The name (Id) of the device. + :param str status: Initial state of the created device. + (Possible values: "enabled" or "disabled"). - :raises: HttpOperationError if the HTTP response status is not in [200]. + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. :returns: Device object containing the created device. """ @@ -107,18 +148,20 @@ class IoTHubRegistryManager(object): } device = Device(**kwargs) - return self.protocol.service.create_or_update_device(device_id, device) + return self.protocol.registry_manager.create_or_update_device(device_id, device) def update_device_with_sas(self, device_id, etag, primary_key, secondary_key, status): """Updates a device identity on IoTHub using SAS authentication. - :param str device_id: The name (deviceId) of the device. + :param str device_id: The name (Id) of the device. :param str etag: The etag (if_match) value to use for the update operation. :param str primary_key: Primary authentication key. :param str secondary_key: Secondary authentication key. - :param str status: Initital state of the created device (enabled or disabled). + :param str status: Initital state of the created device. + (Possible values: "enabled" or "disabled"). - :raises: HttpOperationError if the HTTP response status is not in [200]. + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. :returns: The updated Device object containing the created device. """ @@ -132,20 +175,22 @@ class IoTHubRegistryManager(object): } device = Device(**kwargs) - return self.protocol.service.create_or_update_device(device_id, device, "*") + return self.protocol.registry_manager.create_or_update_device(device_id, device, "*") def update_device_with_x509( self, device_id, etag, primary_thumbprint, secondary_thumbprint, status ): """Updates a device identity on IoTHub using X509 authentication. - :param str device_id: The name (deviceId) of the device. + :param str device_id: The name (Id) of the device. :param str etag: The etag (if_match) value to use for the update operation. :param str primary_thumbprint: Primary X509 thumbprint. :param str secondary_thumbprint: Secondary X509 thumbprint. - :param str status: Initital state of the created device (enabled or disabled). + :param str status: Initital state of the created device. + (Possible values: "enabled" or "disabled"). - :raises: HttpOperationError if the HTTP response status is not in [200]. + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. :returns: The updated Device object containing the created device. """ @@ -163,16 +208,18 @@ class IoTHubRegistryManager(object): } device = Device(**kwargs) - return self.protocol.service.create_or_update_device(device_id, device) + return self.protocol.registry_manager.create_or_update_device(device_id, device) def update_device_with_certificate_authority(self, device_id, etag, status): """Updates a device identity on IoTHub using certificate authority. - :param str device_id: The name (deviceId) of the device. + :param str device_id: The name (Id) of the device. :param str etag: The etag (if_match) value to use for the update operation. - :param str status: Initital state of the created device (enabled or disabled). + :param str status: Initital state of the created device. + (Possible values: "enabled" or "disabled"). - :raises: HttpOperationError if the HTTP response status is not in [200]. + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. :returns: The updated Device object containing the created device. """ @@ -184,75 +231,47 @@ class IoTHubRegistryManager(object): } device = Device(**kwargs) - return self.protocol.service.create_or_update_device(device_id, device) + return self.protocol.registry_manager.create_or_update_device(device_id, device) def get_device(self, device_id): """Retrieves a device identity from IoTHub. - :param str device_id: The name (deviceId) of the device. + :param str device_id: The name (Id) of the device. - :raises: HttpOperationError if the HTTP response status is not in [200]. + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. :returns: The Device object containing the requested device. """ - return self.protocol.service.get_device(device_id) - - def get_configuration(self, device_id): - """Retrieves the IoTHub configuration for a particular device. - - :param str device_id: The name (deviceId) of the device. - - :raises: HttpOperationError if the HTTP response status is not in [200]. - - :returns: The Configuration object. - """ - return self.protocol.service.get_configuration(id) + return self.protocol.registry_manager.get_device(device_id) def delete_device(self, device_id, etag=None): """Deletes a device identity from IoTHub. - :param str device_id: The name (deviceId) of the device. + :param str device_id: The name (Id) of the device. + :param str etag: The etag (if_match) value to use for the delete operation. - :raises: HttpOperationError if the HTTP response status is not in [200]. + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. :returns: None. """ if etag is None: etag = "*" - self.protocol.service.delete_device(device_id, etag) + self.protocol.registry_manager.delete_device(device_id, etag) - def get_service_statistics(self): - """Retrieves the IoTHub service statistics. - - :raises: HttpOperationError if the HTTP response status is not in [200]. - - :returns: The ServiceStatistics object. - """ - return self.protocol.service.get_service_statistics() - - def get_device_registry_statistics(self): - """Retrieves the IoTHub device registry statistics. - - :raises: HttpOperationError if the HTTP response status is not in [200]. - - :returns: The RegistryStatistics object. - """ - return self.protocol.service.get_device_registry_statistics() - - def create_module_with_sas( - self, device_id, module_id, managed_by, primary_key, secondary_key, status - ): + def create_module_with_sas(self, device_id, module_id, managed_by, primary_key, secondary_key): """Creates a module identity for a device on IoTHub using SAS authentication. - :param str device_id: The name (deviceId) of the device. - :param str module_id: The name (moduleID) of the module. + :param str device_id: The name (Id) of the device. + :param str module_id: The name (Id) of the module. :param str managed_by: The name of the manager device (edge). :param str primary_key: Primary authentication key. :param str secondary_key: Secondary authentication key. - :param str status: Initital state of the created device (enabled or disabled). - :raises: HttpOperationError if the HTTP response status is not in [200]. + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. :returns: Module object containing the created module. """ @@ -262,26 +281,25 @@ class IoTHubRegistryManager(object): "device_id": device_id, "module_id": module_id, "managed_by": managed_by, - "status": status, "authentication": AuthenticationMechanism(type="sas", symmetric_key=symmetric_key), } module = Module(**kwargs) - return self.protocol.service.create_or_update_module(device_id, module_id, module) + return self.protocol.registry_manager.create_or_update_module(device_id, module_id, module) def create_module_with_x509( - self, device_id, module_id, managed_by, primary_thumbprint, secondary_thumbprint, status + self, device_id, module_id, managed_by, primary_thumbprint, secondary_thumbprint ): """Creates a module identity for a device on IoTHub using X509 authentication. - :param str device_id: The name (deviceId) of the device. - :param str module_id: The name (moduleID) of the module. + :param str device_id: The name (Id) of the device. + :param str module_id: The name (Id) of the module. :param str managed_by: The name of the manager device (edge). :param str primary_thumbprint: Primary X509 thumbprint. :param str secondary_thumbprint: Secondary X509 thumbprint. - :param str status: Initital state of the created device (enabled or disabled). - :raises: HttpOperationError if the HTTP response status is not in [200]. + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. :returns: Module object containing the created module. """ @@ -293,24 +311,23 @@ class IoTHubRegistryManager(object): "device_id": device_id, "module_id": module_id, "managed_by": managed_by, - "status": status, "authentication": AuthenticationMechanism( type="selfSigned", x509_thumbprint=x509_thumbprint ), } module = Module(**kwargs) - return self.protocol.service.create_or_update_device(device_id, module_id, module) + return self.protocol.registry_manager.create_or_update_module(device_id, module_id, module) - def create_module_with_certificate_authority(self, device_id, module_id, managed_by, status): + def create_module_with_certificate_authority(self, device_id, module_id, managed_by): """Creates a module identity for a device on IoTHub using certificate authority. - :param str device_id: The name (deviceId) of the device. - :param str module_id: The name (moduleID) of the module. + :param str device_id: The name (Id) of the device. + :param str module_id: The name (Id) of the module. :param str managed_by: The name of the manager device (edge). - :param str status: Initital state of the created device (enabled or disabled). - :raises: HttpOperationError if the HTTP response status is not in [200]. + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. :returns: Module object containing the created module. """ @@ -318,27 +335,26 @@ class IoTHubRegistryManager(object): "device_id": device_id, "module_id": module_id, "managed_by": managed_by, - "status": status, "authentication": AuthenticationMechanism(type="certificateAuthority"), } module = Module(**kwargs) - return self.protocol.service.create_or_update_device(device_id, module_id, module) + return self.protocol.registry_manager.create_or_update_module(device_id, module_id, module) def update_module_with_sas( - self, device_id, module_id, managed_by, etag, primary_key, secondary_key, status + self, device_id, module_id, managed_by, etag, primary_key, secondary_key ): """Updates a module identity for a device on IoTHub using SAS authentication. - :param str device_id: The name (deviceId) of the device. - :param str module_id: The name (moduleID) of the module. + :param str device_id: The name (Id) of the device. + :param str module_id: The name (Id) of the module. :param str managed_by: The name of the manager device (edge). :param str etag: The etag (if_match) value to use for the update operation. :param str primary_key: Primary authentication key. :param str secondary_key: Secondary authentication key. - :param str status: Initital state of the created device (enabled or disabled). - :raises: HttpOperationError if the HTTP response status is not in [200]. + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. :returns: The updated Module object containing the created module. """ @@ -348,35 +364,29 @@ class IoTHubRegistryManager(object): "device_id": device_id, "module_id": module_id, "managed_by": managed_by, - "status": status, "etag": etag, "authentication": AuthenticationMechanism(type="sas", symmetric_key=symmetric_key), } module = Module(**kwargs) - return self.protocol.service.create_or_update_device(device_id, module_id, module, "*") + return self.protocol.registry_manager.create_or_update_module( + device_id, module_id, module, "*" + ) def update_module_with_x509( - self, - device_id, - module_id, - managed_by, - etag, - primary_thumbprint, - secondary_thumbprint, - status, + self, device_id, module_id, managed_by, etag, primary_thumbprint, secondary_thumbprint ): """Updates a module identity for a device on IoTHub using X509 authentication. - :param str device_id: The name (deviceId) of the device. - :param str module_id: The name (moduleID) of the module. + :param str device_id: The name (Id) of the device. + :param str module_id: The name (Id) of the module. :param str managed_by: The name of the manager device (edge). :param str etag: The etag (if_match) value to use for the update operation. :param str primary_thumbprint: Primary X509 thumbprint. :param str secondary_thumbprint: Secondary X509 thumbprint. - :param str status: Initital state of the created device (enabled or disabled). - :raises: HttpOperationError if the HTTP response status is not in [200]. + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. :returns: The updated Module object containing the created module. """ @@ -388,7 +398,6 @@ class IoTHubRegistryManager(object): "device_id": device_id, "module_id": module_id, "managed_by": managed_by, - "status": status, "etag": etag, "authentication": AuthenticationMechanism( type="selfSigned", x509_thumbprint=x509_thumbprint @@ -396,20 +405,18 @@ class IoTHubRegistryManager(object): } module = Module(**kwargs) - return self.protocol.service.create_or_update_device(device_id, module_id, module) + return self.protocol.registry_manager.create_or_update_module(device_id, module_id, module) - def update_module_with_certificate_authority( - self, device_id, module_id, managed_by, etag, status - ): + def update_module_with_certificate_authority(self, device_id, module_id, managed_by, etag): """Updates a module identity for a device on IoTHub using certificate authority. - :param str device_id: The name (deviceId) of the device. - :param str module_id: The name (moduleID) of the module. + :param str device_id: The name (Id) of the device. + :param str module_id: The name (Id) of the module. :param str managed_by: The name of the manager device (edge). :param str etag: The etag (if_match) value to use for the update operation. - :param str status: Initital state of the created device (enabled or disabled). - :raises: HttpOperationError if the HTTP response status is not in [200]. + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. :returns: The updated Module object containing the created module. """ @@ -417,36 +424,261 @@ class IoTHubRegistryManager(object): "device_id": device_id, "module_id": module_id, "managed_by": managed_by, - "status": status, "etag": etag, "authentication": AuthenticationMechanism(type="certificateAuthority"), } module = Module(**kwargs) - return self.protocol.service.create_or_update_device(device_id, module_id, module) + return self.protocol.registry_manager.create_or_update_module(device_id, module_id, module) def get_module(self, device_id, module_id): """Retrieves a module identity for a device from IoTHub. - :param str device_id: The name (deviceId) of the device. - :param str module_id: The name (moduleId) of the module. + :param str device_id: The name (Id) of the device. + :param str module_id: The name (Id) of the module. - :raises: HttpOperationError if the HTTP response status is not in [200]. + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. :returns: The Module object containing the requested module. """ - return self.protocol.service.get_module(device_id, module_id) + return self.protocol.registry_manager.get_module(device_id, module_id) - def delete_module(self, device_id, etag=None): + def get_modules(self, device_id): + """Retrieves all module identities on a device. + + :param str device_id: The name (Id) of the device. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: The list[Module] containing all the modules on the device. + """ + return self.protocol.registry_manager.get_modules_on_device(device_id) + + def delete_module(self, device_id, module_id, etag=None): """Deletes a module identity for a device from IoTHub. - :param str device_id: The name (deviceId) of the device. + :param str device_id: The name (Id) of the device. + :param str module_id: The name (Id) of the module. + :param str etag: The etag (if_match) value to use for the delete operation. - :raises: HttpOperationError if the HTTP response status is not in [200]. + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. :returns: None. """ if etag is None: etag = "*" - self.protocol.service.delete_module(device_id, etag) + self.protocol.registry_manager.delete_module(device_id, module_id, etag) + + def get_service_statistics(self): + """Retrieves the IoTHub service statistics. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: The ServiceStatistics object. + """ + return self.protocol.registry_manager.get_service_statistics() + + def get_device_registry_statistics(self): + """Retrieves the IoTHub device registry statistics. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: The RegistryStatistics object. + """ + return self.protocol.registry_manager.get_device_statistics() + + def get_devices(self, max_number_of_devices=None): + """Get the identities of multiple devices from the IoTHub identity + registry. Not recommended. Use the IoTHub query language to retrieve + device twin and device identity information. See + https://docs.microsoft.com/en-us/rest/api/iothub/service/queryiothub + and + https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-query-language + for more information. + + :param int max_number_of_devices: This parameter when specified, defines the maximum number + of device identities that are returned. Any value outside the range of + 1-1000 is considered to be 1000 + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: List of device info. + """ + return self.protocol.registry_manager.get_devices(max_number_of_devices) + + def bulk_create_or_update_devices(self, devices): + """Create, update, or delete the identities of multiple devices from the + IoTHub identity registry. + + Create, update, or delete the identities of multiple devices from the + IoTHub identity registry. A device identity can be specified only once + in the list. Different operations (create, update, delete) on different + devices are allowed. A maximum of 100 devices can be specified per + invocation. For large scale operations, consider using the import + feature using blob + storage(https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-identity-registry#import-and-export-device-identities). + + :param list[ExportImportDevice] devices: The list of device objects to operate on. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: The BulkRegistryOperationResult object. + """ + return self.protocol.registry_manager.bulk_device_crud(devices) + + def query_iot_hub(self, query_specification, continuation_token=None, max_item_count=None): + """Query an IoTHub to retrieve information regarding device twins using a + SQL-like language. + See https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-query-language + for more information. Pagination of results is supported. This returns + information about device twins only. + + :param QuerySpecification query: The query specification. + :param str continuation_token: Continuation token for paging + :param str max_item_count: Maximum number of requested device twins + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: The QueryResult object. + """ + raw_response = self.protocol.registry_manager.query_iot_hub( + query_specification, continuation_token, max_item_count, None, True + ) + + queryResult = QueryResult() + if raw_response.headers: + queryResult.type = raw_response.headers["x-ms-item-type"] + queryResult.continuation_token = raw_response.headers["x-ms-continuation"] + queryResult.items = raw_response.output + + return queryResult + + def get_twin(self, device_id): + """Gets a device twin. + + :param str device_id: The name (Id) of the device. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: The Twin object. + """ + return self.protocol.twin.get_device_twin(device_id) + + def replace_twin(self, device_id, device_twin): + """Replaces tags and desired properties of a device twin. + + :param str device_id: The name (Id) of the device. + :param Twin device_twin: The twin info of the device. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: The Twin object. + """ + return self.protocol.twin.replace_device_twin(device_id, device_twin) + + def update_twin(self, device_id, device_twin, etag): + """Updates tags and desired properties of a device twin. + + :param str device_id: The name (Id) of the device. + :param Twin device_twin: The twin info of the device. + :param str etag: The etag (if_match) value to use for the update operation. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: The Twin object. + """ + return self.protocol.twin.update_device_twin(device_id, device_twin, etag) + + def get_module_twin(self, device_id, module_id): + """Gets a module twin. + + :param str device_id: The name (Id) of the device. + :param str module_id: The name (Id) of the module. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: The Twin object. + """ + return self.protocol.twin.get_module_twin(device_id, module_id) + + def replace_module_twin(self, device_id, module_id, module_twin): + """Replaces tags and desired properties of a module twin. + + :param str device_id: The name (Id) of the device. + :param str module_id: The name (Id) of the module. + :param Twin module_twin: The twin info of the module. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: The Twin object. + """ + return self.protocol.twin.replace_module_twin(device_id, module_id, module_twin) + + def update_module_twin(self, device_id, module_id, module_twin, etag): + """Updates tags and desired properties of a module twin. + + :param str device_id: The name (Id) of the device. + :param str module_id: The name (Id) of the module. + :param Twin module_twin: The twin info of the module. + :param str etag: The etag (if_match) value to use for the update operation. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: The Twin object. + """ + return self.protocol.twin.update_module_twin(device_id, module_id, module_twin, etag) + + def invoke_device_method(self, device_id, direct_method_request): + """Invoke a direct method on a device. + + :param str device_id: The name (Id) of the device. + :param CloudToDeviceMethod direct_method_request: The method request. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: The CloudToDeviceMethodResult object. + """ + return self.protocol.device_method.invoke_device_method(device_id, direct_method_request) + + def invoke_device_module_method(self, device_id, module_id, direct_method_request): + """Invoke a direct method on a device. + + :param str device_id: The name (Id) of the device. + :param str module_id: The name (Id) of the module. + :param CloudToDeviceMethod direct_method_request: The method request. + + :raises: `HttpOperationError` + if the HTTP response status is not in [200]. + + :returns: The CloudToDeviceMethodResult object. + """ + return self.protocol.device_method.invoke_module_method( + device_id, module_id, direct_method_request + ) + + def send_c2d_message(self, device_id, message): + """Send a C2D mesage to a IoTHub Device. + + :param str device_id: The name (Id) of the device. + :param str message: The message that is to be delievered to the device. + + :raises: Exception if the Send command is not able to send the message + """ + + self.amqp_svc_client.send_message_to_device(device_id, message) diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/edge_pipeline.py b/azure-iot-hub/azure/iot/hub/models.py similarity index 58% rename from azure-iot-device/azure/iot/device/iothub/pipeline/edge_pipeline.py rename to azure-iot-hub/azure/iot/hub/models.py index eddd7aa9f..23b57087b 100644 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/edge_pipeline.py +++ b/azure-iot-hub/azure/iot/hub/models.py @@ -1,18 +1,9 @@ -# -------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- - -import logging - -logger = logging.getLogger(__name__) - - -class EdgePipeline(object): - """Pipeline to communicate with Edge. - Uses HTTP. - """ - - def __init__(self, auth_provider): - pass +"""This module imports and re-exposes the contents of the .protocol.models +subpacakge to be better exposed through the user API surface +""" +from .protocol.models import * diff --git a/azure-iot-hub/azure/iot/hub/protocol/__init__.py b/azure-iot-hub/azure/iot/hub/protocol/__init__.py index c2fa6e1f5..7b51cacb0 100644 --- a/azure-iot-hub/azure/iot/hub/protocol/__init__.py +++ b/azure-iot-hub/azure/iot/hub/protocol/__init__.py @@ -5,9 +5,9 @@ # regenerated. # -------------------------------------------------------------------------- -from .iot_hub_gateway_service_ap_is20190701_preview import IotHubGatewayServiceAPIs20190701Preview +from .iot_hub_gateway_service_ap_is import IotHubGatewayServiceAPIs from .version import VERSION -__all__ = ["IotHubGatewayServiceAPIs20190701Preview"] +__all__ = ["IotHubGatewayServiceAPIs"] __version__ = VERSION diff --git a/azure-iot-hub/azure/iot/hub/protocol/iot_hub_gateway_service_ap_is.py b/azure-iot-hub/azure/iot/hub/protocol/iot_hub_gateway_service_ap_is.py new file mode 100644 index 000000000..4f4016100 --- /dev/null +++ b/azure-iot-hub/azure/iot/hub/protocol/iot_hub_gateway_service_ap_is.py @@ -0,0 +1,102 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is +# regenerated. +# -------------------------------------------------------------------------- + +from msrest.service_client import SDKClient +from msrest import Configuration, Serializer, Deserializer +from .version import VERSION +from msrest.exceptions import HttpOperationError +from .operations.configuration_operations import ConfigurationOperations +from .operations.registry_manager_operations import RegistryManagerOperations +from .operations.job_client_operations import JobClientOperations +from .operations.fault_injection_operations import FaultInjectionOperations +from .operations.twin_operations import TwinOperations +from .operations.http_runtime_operations import HttpRuntimeOperations +from .operations.device_method_operations import DeviceMethodOperations +from . import models + + +class IotHubGatewayServiceAPIsConfiguration(Configuration): + """Configuration for IotHubGatewayServiceAPIs + Note that all parameters used to create this instance are saved as instance + attributes. + + :param credentials: Subscription credentials which uniquely identify + client subscription. + :type credentials: None + :param str base_url: Service URL + """ + + def __init__(self, credentials, base_url=None): + + if credentials is None: + raise ValueError("Parameter 'credentials' must not be None.") + if not base_url: + base_url = "https://fully-qualified-iothubname.azure-devices.net" + + super(IotHubGatewayServiceAPIsConfiguration, self).__init__(base_url) + + self.add_user_agent("iothubgatewayserviceapis/{}".format(VERSION)) + + self.credentials = credentials + + +class IotHubGatewayServiceAPIs(SDKClient): + """IotHubGatewayServiceAPIs + + :ivar config: Configuration for client. + :vartype config: IotHubGatewayServiceAPIsConfiguration + + :ivar configuration: Configuration operations + :vartype configuration: protocol.operations.ConfigurationOperations + :ivar registry_manager: RegistryManager operations + :vartype registry_manager: protocol.operations.RegistryManagerOperations + :ivar job_client: JobClient operations + :vartype job_client: protocol.operations.JobClientOperations + :ivar fault_injection: FaultInjection operations + :vartype fault_injection: protocol.operations.FaultInjectionOperations + :ivar twin: Twin operations + :vartype twin: protocol.operations.TwinOperations + :ivar http_runtime: HttpRuntime operations + :vartype http_runtime: protocol.operations.HttpRuntimeOperations + :ivar device_method: DeviceMethod operations + :vartype device_method: protocol.operations.DeviceMethodOperations + + :param credentials: Subscription credentials which uniquely identify + client subscription. + :type credentials: None + :param str base_url: Service URL + """ + + def __init__(self, credentials, base_url=None): + + self.config = IotHubGatewayServiceAPIsConfiguration(credentials, base_url) + super(IotHubGatewayServiceAPIs, self).__init__(self.config.credentials, self.config) + + client_models = {k: v for k, v in models.__dict__.items() if isinstance(v, type)} + self.api_version = "2020-03-13" + self._serialize = Serializer(client_models) + self._deserialize = Deserializer(client_models) + + self.configuration = ConfigurationOperations( + self._client, self.config, self._serialize, self._deserialize + ) + self.registry_manager = RegistryManagerOperations( + self._client, self.config, self._serialize, self._deserialize + ) + self.job_client = JobClientOperations( + self._client, self.config, self._serialize, self._deserialize + ) + self.fault_injection = FaultInjectionOperations( + self._client, self.config, self._serialize, self._deserialize + ) + self.twin = TwinOperations(self._client, self.config, self._serialize, self._deserialize) + self.http_runtime = HttpRuntimeOperations( + self._client, self.config, self._serialize, self._deserialize + ) + self.device_method = DeviceMethodOperations( + self._client, self.config, self._serialize, self._deserialize + ) diff --git a/azure-iot-hub/azure/iot/hub/protocol/models/__init__.py b/azure-iot-hub/azure/iot/hub/protocol/models/__init__.py index 7d5ff5d7f..db0453b38 100644 --- a/azure-iot-hub/azure/iot/hub/protocol/models/__init__.py +++ b/azure-iot-hub/azure/iot/hub/protocol/models/__init__.py @@ -24,13 +24,12 @@ try: from .device_registry_operation_warning_py3 import DeviceRegistryOperationWarning from .bulk_registry_operation_result_py3 import BulkRegistryOperationResult from .query_specification_py3 import QuerySpecification - from .query_result_py3 import QueryResult + from .twin_properties_py3 import TwinProperties + from .twin_py3 import Twin from .job_properties_py3 import JobProperties from .purge_message_queue_result_py3 import PurgeMessageQueueResult from .fault_injection_connection_properties_py3 import FaultInjectionConnectionProperties from .fault_injection_properties_py3 import FaultInjectionProperties - from .twin_properties_py3 import TwinProperties - from .twin_py3 import Twin from .desired_state_py3 import DesiredState from .reported_py3 import Reported from .desired_py3 import Desired @@ -41,6 +40,7 @@ try: from .job_request_py3 import JobRequest from .device_job_statistics_py3 import DeviceJobStatistics from .job_response_py3 import JobResponse + from .query_result_py3 import QueryResult from .module_py3 import Module from .cloud_to_device_method_result_py3 import CloudToDeviceMethodResult from .digital_twin_interfaces_patch_interfaces_value_properties_value_desired_py3 import ( @@ -72,13 +72,12 @@ except (SyntaxError, ImportError): from .device_registry_operation_warning import DeviceRegistryOperationWarning from .bulk_registry_operation_result import BulkRegistryOperationResult from .query_specification import QuerySpecification - from .query_result import QueryResult + from .twin_properties import TwinProperties + from .twin import Twin from .job_properties import JobProperties from .purge_message_queue_result import PurgeMessageQueueResult from .fault_injection_connection_properties import FaultInjectionConnectionProperties from .fault_injection_properties import FaultInjectionProperties - from .twin_properties import TwinProperties - from .twin import Twin from .desired_state import DesiredState from .reported import Reported from .desired import Desired @@ -89,6 +88,7 @@ except (SyntaxError, ImportError): from .job_request import JobRequest from .device_job_statistics import DeviceJobStatistics from .job_response import JobResponse + from .query_result import QueryResult from .module import Module from .cloud_to_device_method_result import CloudToDeviceMethodResult from .digital_twin_interfaces_patch_interfaces_value_properties_value_desired import ( @@ -121,13 +121,12 @@ __all__ = [ "DeviceRegistryOperationWarning", "BulkRegistryOperationResult", "QuerySpecification", - "QueryResult", + "TwinProperties", + "Twin", "JobProperties", "PurgeMessageQueueResult", "FaultInjectionConnectionProperties", "FaultInjectionProperties", - "TwinProperties", - "Twin", "DesiredState", "Reported", "Desired", @@ -138,6 +137,7 @@ __all__ = [ "JobRequest", "DeviceJobStatistics", "JobResponse", + "QueryResult", "Module", "CloudToDeviceMethodResult", "DigitalTwinInterfacesPatchInterfacesValuePropertiesValueDesired", diff --git a/azure-iot-hub/azure/iot/hub/protocol/models/device_registry_operation_error.py b/azure-iot-hub/azure/iot/hub/protocol/models/device_registry_operation_error.py index c80f4aa26..2b603cba4 100644 --- a/azure-iot-hub/azure/iot/hub/protocol/models/device_registry_operation_error.py +++ b/azure-iot-hub/azure/iot/hub/protocol/models/device_registry_operation_error.py @@ -28,7 +28,8 @@ class DeviceRegistryOperationError(Model): 'RequestTimedOut', 'UnsupportedOperationOnReplica', 'NullMessage', 'ConnectionForcefullyClosedOnNewConnection', 'InvalidDeviceScope', 'ConnectionForcefullyClosedOnFaultInjection', - 'ConnectionRejectedOnFaultInjection', 'InvalidRouteTestInput', + 'ConnectionRejectedOnFaultInjection', 'InvalidEndpointAuthenticationType', + 'ManagedIdentityNotEnabled', 'InvalidRouteTestInput', 'InvalidSourceOnRoute', 'RoutingNotEnabled', 'InvalidContentEncodingOrType', 'InvalidEndorsementKey', 'InvalidRegistrationId', 'InvalidStorageRootKey', @@ -41,15 +42,16 @@ class DeviceRegistryOperationError(Model): 'CannotModifyImmutableConfigurationContent', 'InvalidConfigurationCustomMetricsQuery', 'InvalidPnPInterfaceDefinition', 'InvalidPnPDesiredProperties', 'InvalidPnPReportedProperties', - 'InvalidPnPWritableReportedProperties', 'GenericUnauthorized', - 'IotHubNotFound', 'IotHubUnauthorizedAccess', 'IotHubUnauthorized', - 'ElasticPoolNotFound', 'SystemModuleModifyUnauthorizedAccess', - 'GenericForbidden', 'IotHubSuspended', 'IotHubQuotaExceeded', - 'JobQuotaExceeded', 'DeviceMaximumQueueDepthExceeded', - 'IotHubMaxCbsTokenExceeded', 'DeviceMaximumActiveFileUploadLimitExceeded', + 'InvalidPnPWritableReportedProperties', 'InvalidDigitalTwinJsonPatch', + 'GenericUnauthorized', 'IotHubNotFound', 'IotHubUnauthorizedAccess', + 'IotHubUnauthorized', 'ElasticPoolNotFound', + 'SystemModuleModifyUnauthorizedAccess', 'GenericForbidden', + 'IotHubSuspended', 'IotHubQuotaExceeded', 'JobQuotaExceeded', + 'DeviceMaximumQueueDepthExceeded', 'IotHubMaxCbsTokenExceeded', + 'DeviceMaximumActiveFileUploadLimitExceeded', 'DeviceMaximumQueueSizeExceeded', 'RoutingEndpointResponseForbidden', 'InvalidMessageExpiryTime', 'OperationNotAvailableInCurrentTier', - 'DeviceModelMaxPropertiesExceeded', + 'KeyEncryptionKeyRevoked', 'DeviceModelMaxPropertiesExceeded', 'DeviceModelMaxIndexablePropertiesExceeded', 'IotDpsSuspended', 'IotDpsSuspending', 'GenericNotFound', 'DeviceNotFound', 'JobNotFound', 'QuotaMetricNotFound', 'SystemPropertyNotFound', 'AmqpAddressNotFound', @@ -82,12 +84,13 @@ class DeviceRegistryOperationError(Model): 'InflightMessagesInLink', 'GenericRequestEntityTooLarge', 'MessageTooLarge', 'TooManyDevices', 'TooManyModulesOnDevice', 'ConfigurationCountLimitExceeded', 'DigitalTwinModelCountLimitExceeded', + 'InterfaceNameCompressionModelCountLimitExceeded', 'GenericUnsupportedMediaType', 'IncompatibleDataType', 'GenericTooManyRequests', 'ThrottlingException', 'ThrottleBacklogLimitExceeded', 'ThrottlingBacklogTimeout', - 'ThrottlingMaxActiveJobCountExceeded', 'ClientClosedRequest', - 'GenericServerError', 'ServerError', 'JobCancelled', - 'StatisticsRetrievalError', 'ConnectionForcefullyClosed', + 'ThrottlingMaxActiveJobCountExceeded', 'DeviceThrottlingLimitExceeded', + 'ClientClosedRequest', 'GenericServerError', 'ServerError', + 'JobCancelled', 'StatisticsRetrievalError', 'ConnectionForcefullyClosed', 'InvalidBlobState', 'BackupTimedOut', 'AzureStorageTimeout', 'GenericTimeout', 'InvalidThrottleParameter', 'EventHubLinkAlreadyClosed', 'ReliableBlobStoreError', 'RetryAttemptsExhausted', @@ -95,14 +98,17 @@ class DeviceRegistryOperationError(Model): 'DocumentDbInvalidReturnValue', 'ReliableDocDbStoreStoreError', 'ReliableBlobStoreTimeoutError', 'ConfigReadFailed', 'InvalidContainerReceiveLink', 'InvalidPartitionEpoch', 'RestoreTimedOut', - 'StreamReservationFailure', 'UnexpectedPropertyValue', - 'OrchestrationOperationFailed', 'ModelRepoEndpointError', - 'ResolutionError', 'GenericBadGateway', 'InvalidResponseWhileProxying', + 'StreamReservationFailure', 'SerializationError', + 'UnexpectedPropertyValue', 'OrchestrationOperationFailed', + 'ModelRepoEndpointError', 'ResolutionError', 'UnableToFetchCredentials', + 'UnableToFetchTenantInfo', 'UnableToShareIdentity', + 'UnableToExpandDiscoveryInfo', 'UnableToExpandComponentInfo', + 'GenericBadGateway', 'InvalidResponseWhileProxying', 'GenericServiceUnavailable', 'ServiceUnavailable', 'PartitionNotFound', 'IotHubActivationFailed', 'ServerBusy', 'IotHubRestoring', 'ReceiveLinkOpensThrottled', 'ConnectionUnavailable', 'DeviceUnavailable', - 'ConfigurationNotAvailable', 'GroupNotAvailable', 'GenericGatewayTimeout', - 'GatewayTimeout' + 'ConfigurationNotAvailable', 'GroupNotAvailable', + 'HostingServiceNotAvailable', 'GenericGatewayTimeout', 'GatewayTimeout' :type error_code: str or ~protocol.models.enum :param error_status: Additional details associated with the error. :type error_status: str diff --git a/azure-iot-hub/azure/iot/hub/protocol/models/device_registry_operation_error_py3.py b/azure-iot-hub/azure/iot/hub/protocol/models/device_registry_operation_error_py3.py index 2342298ca..f37605aca 100644 --- a/azure-iot-hub/azure/iot/hub/protocol/models/device_registry_operation_error_py3.py +++ b/azure-iot-hub/azure/iot/hub/protocol/models/device_registry_operation_error_py3.py @@ -28,7 +28,8 @@ class DeviceRegistryOperationError(Model): 'RequestTimedOut', 'UnsupportedOperationOnReplica', 'NullMessage', 'ConnectionForcefullyClosedOnNewConnection', 'InvalidDeviceScope', 'ConnectionForcefullyClosedOnFaultInjection', - 'ConnectionRejectedOnFaultInjection', 'InvalidRouteTestInput', + 'ConnectionRejectedOnFaultInjection', 'InvalidEndpointAuthenticationType', + 'ManagedIdentityNotEnabled', 'InvalidRouteTestInput', 'InvalidSourceOnRoute', 'RoutingNotEnabled', 'InvalidContentEncodingOrType', 'InvalidEndorsementKey', 'InvalidRegistrationId', 'InvalidStorageRootKey', @@ -41,15 +42,16 @@ class DeviceRegistryOperationError(Model): 'CannotModifyImmutableConfigurationContent', 'InvalidConfigurationCustomMetricsQuery', 'InvalidPnPInterfaceDefinition', 'InvalidPnPDesiredProperties', 'InvalidPnPReportedProperties', - 'InvalidPnPWritableReportedProperties', 'GenericUnauthorized', - 'IotHubNotFound', 'IotHubUnauthorizedAccess', 'IotHubUnauthorized', - 'ElasticPoolNotFound', 'SystemModuleModifyUnauthorizedAccess', - 'GenericForbidden', 'IotHubSuspended', 'IotHubQuotaExceeded', - 'JobQuotaExceeded', 'DeviceMaximumQueueDepthExceeded', - 'IotHubMaxCbsTokenExceeded', 'DeviceMaximumActiveFileUploadLimitExceeded', + 'InvalidPnPWritableReportedProperties', 'InvalidDigitalTwinJsonPatch', + 'GenericUnauthorized', 'IotHubNotFound', 'IotHubUnauthorizedAccess', + 'IotHubUnauthorized', 'ElasticPoolNotFound', + 'SystemModuleModifyUnauthorizedAccess', 'GenericForbidden', + 'IotHubSuspended', 'IotHubQuotaExceeded', 'JobQuotaExceeded', + 'DeviceMaximumQueueDepthExceeded', 'IotHubMaxCbsTokenExceeded', + 'DeviceMaximumActiveFileUploadLimitExceeded', 'DeviceMaximumQueueSizeExceeded', 'RoutingEndpointResponseForbidden', 'InvalidMessageExpiryTime', 'OperationNotAvailableInCurrentTier', - 'DeviceModelMaxPropertiesExceeded', + 'KeyEncryptionKeyRevoked', 'DeviceModelMaxPropertiesExceeded', 'DeviceModelMaxIndexablePropertiesExceeded', 'IotDpsSuspended', 'IotDpsSuspending', 'GenericNotFound', 'DeviceNotFound', 'JobNotFound', 'QuotaMetricNotFound', 'SystemPropertyNotFound', 'AmqpAddressNotFound', @@ -82,12 +84,13 @@ class DeviceRegistryOperationError(Model): 'InflightMessagesInLink', 'GenericRequestEntityTooLarge', 'MessageTooLarge', 'TooManyDevices', 'TooManyModulesOnDevice', 'ConfigurationCountLimitExceeded', 'DigitalTwinModelCountLimitExceeded', + 'InterfaceNameCompressionModelCountLimitExceeded', 'GenericUnsupportedMediaType', 'IncompatibleDataType', 'GenericTooManyRequests', 'ThrottlingException', 'ThrottleBacklogLimitExceeded', 'ThrottlingBacklogTimeout', - 'ThrottlingMaxActiveJobCountExceeded', 'ClientClosedRequest', - 'GenericServerError', 'ServerError', 'JobCancelled', - 'StatisticsRetrievalError', 'ConnectionForcefullyClosed', + 'ThrottlingMaxActiveJobCountExceeded', 'DeviceThrottlingLimitExceeded', + 'ClientClosedRequest', 'GenericServerError', 'ServerError', + 'JobCancelled', 'StatisticsRetrievalError', 'ConnectionForcefullyClosed', 'InvalidBlobState', 'BackupTimedOut', 'AzureStorageTimeout', 'GenericTimeout', 'InvalidThrottleParameter', 'EventHubLinkAlreadyClosed', 'ReliableBlobStoreError', 'RetryAttemptsExhausted', @@ -95,14 +98,17 @@ class DeviceRegistryOperationError(Model): 'DocumentDbInvalidReturnValue', 'ReliableDocDbStoreStoreError', 'ReliableBlobStoreTimeoutError', 'ConfigReadFailed', 'InvalidContainerReceiveLink', 'InvalidPartitionEpoch', 'RestoreTimedOut', - 'StreamReservationFailure', 'UnexpectedPropertyValue', - 'OrchestrationOperationFailed', 'ModelRepoEndpointError', - 'ResolutionError', 'GenericBadGateway', 'InvalidResponseWhileProxying', + 'StreamReservationFailure', 'SerializationError', + 'UnexpectedPropertyValue', 'OrchestrationOperationFailed', + 'ModelRepoEndpointError', 'ResolutionError', 'UnableToFetchCredentials', + 'UnableToFetchTenantInfo', 'UnableToShareIdentity', + 'UnableToExpandDiscoveryInfo', 'UnableToExpandComponentInfo', + 'GenericBadGateway', 'InvalidResponseWhileProxying', 'GenericServiceUnavailable', 'ServiceUnavailable', 'PartitionNotFound', 'IotHubActivationFailed', 'ServerBusy', 'IotHubRestoring', 'ReceiveLinkOpensThrottled', 'ConnectionUnavailable', 'DeviceUnavailable', - 'ConfigurationNotAvailable', 'GroupNotAvailable', 'GenericGatewayTimeout', - 'GatewayTimeout' + 'ConfigurationNotAvailable', 'GroupNotAvailable', + 'HostingServiceNotAvailable', 'GenericGatewayTimeout', 'GatewayTimeout' :type error_code: str or ~protocol.models.enum :param error_status: Additional details associated with the error. :type error_status: str diff --git a/azure-iot-hub/azure/iot/hub/protocol/models/job_properties.py b/azure-iot-hub/azure/iot/hub/protocol/models/job_properties.py index 5978959cf..ae99d48fa 100644 --- a/azure-iot-hub/azure/iot/hub/protocol/models/job_properties.py +++ b/azure-iot-hub/azure/iot/hub/protocol/models/job_properties.py @@ -49,6 +49,10 @@ class JobProperties(Model): jobs. Default: false. If false, authorization keys are included in export output. Keys are exported as null otherwise. :type exclude_keys_in_export: bool + :param storage_authentication_type: Specifies authentication type being + used for connecting to storage account. Possible values include: + 'keyBased', 'identityBased' + :type storage_authentication_type: str or ~protocol.models.enum :param failure_reason: System genereated. Ignored at creation. If status == failure, this represents a string containing the reason. :type failure_reason: str @@ -66,6 +70,7 @@ class JobProperties(Model): "output_blob_container_uri": {"key": "outputBlobContainerUri", "type": "str"}, "output_blob_name": {"key": "outputBlobName", "type": "str"}, "exclude_keys_in_export": {"key": "excludeKeysInExport", "type": "bool"}, + "storage_authentication_type": {"key": "storageAuthenticationType", "type": "str"}, "failure_reason": {"key": "failureReason", "type": "str"}, } @@ -82,4 +87,5 @@ class JobProperties(Model): self.output_blob_container_uri = kwargs.get("output_blob_container_uri", None) self.output_blob_name = kwargs.get("output_blob_name", None) self.exclude_keys_in_export = kwargs.get("exclude_keys_in_export", None) + self.storage_authentication_type = kwargs.get("storage_authentication_type", None) self.failure_reason = kwargs.get("failure_reason", None) diff --git a/azure-iot-hub/azure/iot/hub/protocol/models/job_properties_py3.py b/azure-iot-hub/azure/iot/hub/protocol/models/job_properties_py3.py index 45f0146ce..f2432c0fa 100644 --- a/azure-iot-hub/azure/iot/hub/protocol/models/job_properties_py3.py +++ b/azure-iot-hub/azure/iot/hub/protocol/models/job_properties_py3.py @@ -49,6 +49,10 @@ class JobProperties(Model): jobs. Default: false. If false, authorization keys are included in export output. Keys are exported as null otherwise. :type exclude_keys_in_export: bool + :param storage_authentication_type: Specifies authentication type being + used for connecting to storage account. Possible values include: + 'keyBased', 'identityBased' + :type storage_authentication_type: str or ~protocol.models.enum :param failure_reason: System genereated. Ignored at creation. If status == failure, this represents a string containing the reason. :type failure_reason: str @@ -66,6 +70,7 @@ class JobProperties(Model): "output_blob_container_uri": {"key": "outputBlobContainerUri", "type": "str"}, "output_blob_name": {"key": "outputBlobName", "type": "str"}, "exclude_keys_in_export": {"key": "excludeKeysInExport", "type": "bool"}, + "storage_authentication_type": {"key": "storageAuthenticationType", "type": "str"}, "failure_reason": {"key": "failureReason", "type": "str"}, } @@ -83,6 +88,7 @@ class JobProperties(Model): output_blob_container_uri: str = None, output_blob_name: str = None, exclude_keys_in_export: bool = None, + storage_authentication_type=None, failure_reason: str = None, **kwargs ) -> None: @@ -98,4 +104,5 @@ class JobProperties(Model): self.output_blob_container_uri = output_blob_container_uri self.output_blob_name = output_blob_name self.exclude_keys_in_export = exclude_keys_in_export + self.storage_authentication_type = storage_authentication_type self.failure_reason = failure_reason diff --git a/azure-iot-hub/azure/iot/hub/protocol/operations/__init__.py b/azure-iot-hub/azure/iot/hub/protocol/operations/__init__.py index f443566d3..a41b5ac49 100644 --- a/azure-iot-hub/azure/iot/hub/protocol/operations/__init__.py +++ b/azure-iot-hub/azure/iot/hub/protocol/operations/__init__.py @@ -5,7 +5,20 @@ # regenerated. # -------------------------------------------------------------------------- -from .service_operations import ServiceOperations -from .digital_twin_operations import DigitalTwinOperations +from .configuration_operations import ConfigurationOperations +from .registry_manager_operations import RegistryManagerOperations +from .job_client_operations import JobClientOperations +from .fault_injection_operations import FaultInjectionOperations +from .twin_operations import TwinOperations +from .http_runtime_operations import HttpRuntimeOperations +from .device_method_operations import DeviceMethodOperations -__all__ = ["ServiceOperations", "DigitalTwinOperations"] +__all__ = [ + "ConfigurationOperations", + "RegistryManagerOperations", + "JobClientOperations", + "FaultInjectionOperations", + "TwinOperations", + "HttpRuntimeOperations", + "DeviceMethodOperations", +] diff --git a/azure-iot-hub/azure/iot/hub/protocol/operations/configuration_operations.py b/azure-iot-hub/azure/iot/hub/protocol/operations/configuration_operations.py new file mode 100644 index 000000000..20f22c699 --- /dev/null +++ b/azure-iot-hub/azure/iot/hub/protocol/operations/configuration_operations.py @@ -0,0 +1,390 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is +# regenerated. +# -------------------------------------------------------------------------- + +from msrest.pipeline import ClientRawResponse +from msrest.exceptions import HttpOperationError + +from .. import models + + +class ConfigurationOperations(object): + """ConfigurationOperations operations. + + :param client: Client for service requests. + :param config: Configuration of service client. + :param serializer: An object model serializer. + :param deserializer: An object model deserializer. + :ivar api_version: Version of the Api. Constant value: '2020-03-13'. + """ + + models = models + + def __init__(self, client, config, serializer, deserializer): + + self._client = client + self._serialize = serializer + self._deserialize = deserializer + + self.config = config + self.api_version = "2020-03-13" + + def get(self, id, custom_headers=None, raw=False, **operation_config): + """Retrieve a configuration for Iot Hub devices and modules by it + identifier. + + :param id: + :type id: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: Configuration or ClientRawResponse if raw=true + :rtype: ~protocol.models.Configuration or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.get.metadata["url"] + path_format_arguments = {"id": self._serialize.url("id", id, "str")} + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.get(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("Configuration", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + get.metadata = {"url": "/configurations/{id}"} + + def create_or_update( + self, id, configuration, if_match=None, custom_headers=None, raw=False, **operation_config + ): + """Create or update the configuration for devices or modules of an IoT + hub. An ETag must not be specified for the create operation. An ETag + must be specified for the update operation. Note that configuration Id + and Content cannot be updated by the user. + + :param id: + :type id: str + :param configuration: + :type configuration: ~protocol.models.Configuration + :param if_match: + :type if_match: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: Configuration or ClientRawResponse if raw=true + :rtype: ~protocol.models.Configuration or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.create_or_update.metadata["url"] + path_format_arguments = {"id": self._serialize.url("id", id, "str")} + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + header_parameters["Content-Type"] = "application/json; charset=utf-8" + if custom_headers: + header_parameters.update(custom_headers) + if if_match is not None: + header_parameters["If-Match"] = self._serialize.header("if_match", if_match, "str") + + # Construct body + body_content = self._serialize.body(configuration, "Configuration") + + # Construct and send request + request = self._client.put(url, query_parameters, header_parameters, body_content) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200, 201]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("Configuration", response) + if response.status_code == 201: + deserialized = self._deserialize("Configuration", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + create_or_update.metadata = {"url": "/configurations/{id}"} + + def delete(self, id, if_match=None, custom_headers=None, raw=False, **operation_config): + """Delete the configuration for devices or modules of an IoT hub. This + request requires the If-Match header. The client may specify the ETag + for the device identity on the request in order to compare to the ETag + maintained by the service for the purpose of optimistic concurrency. + The delete operation is performed only if the ETag sent by the client + matches the value maintained by the server, indicating that the device + identity has not been modified since it was retrieved by the client. To + force an unconditional delete, set If-Match to the wildcard character + (*). + + :param id: + :type id: str + :param if_match: + :type if_match: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: None or ClientRawResponse if raw=true + :rtype: None or ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.delete.metadata["url"] + path_format_arguments = {"id": self._serialize.url("id", id, "str")} + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + if custom_headers: + header_parameters.update(custom_headers) + if if_match is not None: + header_parameters["If-Match"] = self._serialize.header("if_match", if_match, "str") + + # Construct and send request + request = self._client.delete(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [204]: + raise HttpOperationError(self._deserialize, response) + + if raw: + client_raw_response = ClientRawResponse(None, response) + return client_raw_response + + delete.metadata = {"url": "/configurations/{id}"} + + def get_configurations(self, top=None, custom_headers=None, raw=False, **operation_config): + """Get multiple configurations for devices or modules of an IoT Hub. + Returns the specified number of configurations for Iot Hub. Pagination + is not supported. + + :param top: + :type top: int + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: list or ClientRawResponse if raw=true + :rtype: list[~protocol.models.Configuration] or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.get_configurations.metadata["url"] + + # Construct parameters + query_parameters = {} + if top is not None: + query_parameters["top"] = self._serialize.query("top", top, "int") + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.get(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("[Configuration]", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + get_configurations.metadata = {"url": "/configurations"} + + def test_queries(self, input, custom_headers=None, raw=False, **operation_config): + """Validates the target condition query and custom metric queries for a + configuration. + + Validates the target condition query and custom metric queries for a + configuration. + + :param input: + :type input: ~protocol.models.ConfigurationQueriesTestInput + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: ConfigurationQueriesTestResponse or ClientRawResponse if + raw=true + :rtype: ~protocol.models.ConfigurationQueriesTestResponse or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.test_queries.metadata["url"] + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + header_parameters["Content-Type"] = "application/json; charset=utf-8" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct body + body_content = self._serialize.body(input, "ConfigurationQueriesTestInput") + + # Construct and send request + request = self._client.post(url, query_parameters, header_parameters, body_content) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("ConfigurationQueriesTestResponse", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + test_queries.metadata = {"url": "/configurations/testQueries"} + + def apply_on_edge_device(self, id, content, custom_headers=None, raw=False, **operation_config): + """Applies the provided configuration content to the specified edge + device. + + Applies the provided configuration content to the specified edge + device. Configuration content must have modules content. + + :param id: Device ID. + :type id: str + :param content: Configuration Content. + :type content: ~protocol.models.ConfigurationContent + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: object or ClientRawResponse if raw=true + :rtype: object or ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.apply_on_edge_device.metadata["url"] + path_format_arguments = {"id": self._serialize.url("id", id, "str")} + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + header_parameters["Content-Type"] = "application/json; charset=utf-8" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct body + body_content = self._serialize.body(content, "ConfigurationContent") + + # Construct and send request + request = self._client.post(url, query_parameters, header_parameters, body_content) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200, 204]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("object", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + apply_on_edge_device.metadata = {"url": "/devices/{id}/applyConfigurationContent"} diff --git a/azure-iot-hub/azure/iot/hub/protocol/operations/device_method_operations.py b/azure-iot-hub/azure/iot/hub/protocol/operations/device_method_operations.py new file mode 100644 index 000000000..5522a84ba --- /dev/null +++ b/azure-iot-hub/azure/iot/hub/protocol/operations/device_method_operations.py @@ -0,0 +1,174 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is +# regenerated. +# -------------------------------------------------------------------------- + +from msrest.pipeline import ClientRawResponse +from msrest.exceptions import HttpOperationError + +from .. import models + + +class DeviceMethodOperations(object): + """DeviceMethodOperations operations. + + :param client: Client for service requests. + :param config: Configuration of service client. + :param serializer: An object model serializer. + :param deserializer: An object model deserializer. + :ivar api_version: Version of the Api. Constant value: '2020-03-13'. + """ + + models = models + + def __init__(self, client, config, serializer, deserializer): + + self._client = client + self._serialize = serializer + self._deserialize = deserializer + + self.config = config + self.api_version = "2020-03-13" + + def invoke_device_method( + self, device_id, direct_method_request, custom_headers=None, raw=False, **operation_config + ): + """Invoke a direct method on a device. + + Invoke a direct method on a device. See + https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-direct-methods + for more information. + + :param device_id: + :type device_id: str + :param direct_method_request: + :type direct_method_request: ~protocol.models.CloudToDeviceMethod + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: CloudToDeviceMethodResult or ClientRawResponse if raw=true + :rtype: ~protocol.models.CloudToDeviceMethodResult or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.invoke_device_method.metadata["url"] + path_format_arguments = {"deviceId": self._serialize.url("device_id", device_id, "str")} + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + header_parameters["Content-Type"] = "application/json; charset=utf-8" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct body + body_content = self._serialize.body(direct_method_request, "CloudToDeviceMethod") + + # Construct and send request + request = self._client.post(url, query_parameters, header_parameters, body_content) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("CloudToDeviceMethodResult", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + invoke_device_method.metadata = {"url": "/twins/{deviceId}/methods"} + + def invoke_module_method( + self, + device_id, + module_id, + direct_method_request, + custom_headers=None, + raw=False, + **operation_config + ): + """Invoke a direct method on a module of a device. + + Invoke a direct method on a module of a device. See + https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-direct-methods + for more information. + + :param device_id: + :type device_id: str + :param module_id: + :type module_id: str + :param direct_method_request: + :type direct_method_request: ~protocol.models.CloudToDeviceMethod + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: CloudToDeviceMethodResult or ClientRawResponse if raw=true + :rtype: ~protocol.models.CloudToDeviceMethodResult or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.invoke_module_method.metadata["url"] + path_format_arguments = { + "deviceId": self._serialize.url("device_id", device_id, "str"), + "moduleId": self._serialize.url("module_id", module_id, "str"), + } + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + header_parameters["Content-Type"] = "application/json; charset=utf-8" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct body + body_content = self._serialize.body(direct_method_request, "CloudToDeviceMethod") + + # Construct and send request + request = self._client.post(url, query_parameters, header_parameters, body_content) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("CloudToDeviceMethodResult", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + invoke_module_method.metadata = {"url": "/twins/{deviceId}/modules/{moduleId}/methods"} diff --git a/azure-iot-hub/azure/iot/hub/protocol/operations/fault_injection_operations.py b/azure-iot-hub/azure/iot/hub/protocol/operations/fault_injection_operations.py new file mode 100644 index 000000000..79a02f0db --- /dev/null +++ b/azure-iot-hub/azure/iot/hub/protocol/operations/fault_injection_operations.py @@ -0,0 +1,128 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is +# regenerated. +# -------------------------------------------------------------------------- + +from msrest.pipeline import ClientRawResponse +from msrest.exceptions import HttpOperationError + +from .. import models + + +class FaultInjectionOperations(object): + """FaultInjectionOperations operations. + + :param client: Client for service requests. + :param config: Configuration of service client. + :param serializer: An object model serializer. + :param deserializer: An object model deserializer. + :ivar api_version: Version of the Api. Constant value: '2020-03-13'. + """ + + models = models + + def __init__(self, client, config, serializer, deserializer): + + self._client = client + self._serialize = serializer + self._deserialize = deserializer + + self.config = config + self.api_version = "2020-03-13" + + def get(self, custom_headers=None, raw=False, **operation_config): + """Get FaultInjection entity. + + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: FaultInjectionProperties or ClientRawResponse if raw=true + :rtype: ~protocol.models.FaultInjectionProperties or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.get.metadata["url"] + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.get(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("FaultInjectionProperties", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + get.metadata = {"url": "/faultInjection"} + + def set(self, value, custom_headers=None, raw=False, **operation_config): + """Create or update FaultInjection entity. + + :param value: + :type value: ~protocol.models.FaultInjectionProperties + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: None or ClientRawResponse if raw=true + :rtype: None or ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.set.metadata["url"] + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Content-Type"] = "application/json; charset=utf-8" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct body + body_content = self._serialize.body(value, "FaultInjectionProperties") + + # Construct and send request + request = self._client.put(url, query_parameters, header_parameters, body_content) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + if raw: + client_raw_response = ClientRawResponse(None, response) + return client_raw_response + + set.metadata = {"url": "/faultInjection"} diff --git a/azure-iot-hub/azure/iot/hub/protocol/operations/http_runtime_operations.py b/azure-iot-hub/azure/iot/hub/protocol/operations/http_runtime_operations.py new file mode 100644 index 000000000..b16e23262 --- /dev/null +++ b/azure-iot-hub/azure/iot/hub/protocol/operations/http_runtime_operations.py @@ -0,0 +1,187 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is +# regenerated. +# -------------------------------------------------------------------------- + +from msrest.pipeline import ClientRawResponse +from msrest.exceptions import HttpOperationError + +from .. import models + + +class HttpRuntimeOperations(object): + """HttpRuntimeOperations operations. + + :param client: Client for service requests. + :param config: Configuration of service client. + :param serializer: An object model serializer. + :param deserializer: An object model deserializer. + :ivar api_version: Version of the Api. Constant value: '2020-03-13'. + """ + + models = models + + def __init__(self, client, config, serializer, deserializer): + + self._client = client + self._serialize = serializer + self._deserialize = deserializer + + self.config = config + self.api_version = "2020-03-13" + + def receive_feedback_notification(self, custom_headers=None, raw=False, **operation_config): + """This method is used to retrieve feedback of a cloud-to-device message. + + This method is used to retrieve feedback of a cloud-to-device message + See https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-messaging + for more information. This capability is only available in the standard + tier IoT Hub. For more information, see [Choose the right IoT Hub + tier](https://aka.ms/scaleyouriotsolution). + + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: None or ClientRawResponse if raw=true + :rtype: None or ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.receive_feedback_notification.metadata["url"] + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.get(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200, 204]: + raise HttpOperationError(self._deserialize, response) + + if raw: + client_raw_response = ClientRawResponse(None, response) + return client_raw_response + + receive_feedback_notification.metadata = {"url": "/messages/serviceBound/feedback"} + + def complete_feedback_notification( + self, lock_token, custom_headers=None, raw=False, **operation_config + ): + """This method completes a feedback message. + + This method completes a feedback message. The lockToken obtained when + the message was received must be provided to resolve race conditions + when completing, a feedback message. A completed message is deleted + from the feedback queue. See + https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-messaging for + more information. + + :param lock_token: Lock token. + :type lock_token: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: None or ClientRawResponse if raw=true + :rtype: None or ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.complete_feedback_notification.metadata["url"] + path_format_arguments = {"lockToken": self._serialize.url("lock_token", lock_token, "str")} + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.delete(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [204]: + raise HttpOperationError(self._deserialize, response) + + if raw: + client_raw_response = ClientRawResponse(None, response) + return client_raw_response + + complete_feedback_notification.metadata = {"url": "/messages/serviceBound/feedback/{lockToken}"} + + def abandon_feedback_notification( + self, lock_token, custom_headers=None, raw=False, **operation_config + ): + """This method abandons a feedback message. + + This method abandons a feedback message. The lockToken obtained when + the message was received must be provided to resolve race conditions + when abandoning, a feedback message. A abandoned message is deleted + from the feedback queue. See + https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-messaging for + more information. + + :param lock_token: Lock Token. + :type lock_token: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: None or ClientRawResponse if raw=true + :rtype: None or ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.abandon_feedback_notification.metadata["url"] + path_format_arguments = {"lockToken": self._serialize.url("lock_token", lock_token, "str")} + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.post(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [204]: + raise HttpOperationError(self._deserialize, response) + + if raw: + client_raw_response = ClientRawResponse(None, response) + return client_raw_response + + abandon_feedback_notification.metadata = { + "url": "/messages/serviceBound/feedback/{lockToken}/abandon" + } diff --git a/azure-iot-hub/azure/iot/hub/protocol/operations/job_client_operations.py b/azure-iot-hub/azure/iot/hub/protocol/operations/job_client_operations.py new file mode 100644 index 000000000..f8683ba6e --- /dev/null +++ b/azure-iot-hub/azure/iot/hub/protocol/operations/job_client_operations.py @@ -0,0 +1,505 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is +# regenerated. +# -------------------------------------------------------------------------- + +from msrest.pipeline import ClientRawResponse +from msrest.exceptions import HttpOperationError + +from .. import models + + +class JobClientOperations(object): + """JobClientOperations operations. + + :param client: Client for service requests. + :param config: Configuration of service client. + :param serializer: An object model serializer. + :param deserializer: An object model deserializer. + :ivar api_version: Version of the Api. Constant value: '2020-03-13'. + """ + + models = models + + def __init__(self, client, config, serializer, deserializer): + + self._client = client + self._serialize = serializer + self._deserialize = deserializer + + self.config = config + self.api_version = "2020-03-13" + + def create_import_export_job( + self, job_properties, custom_headers=None, raw=False, **operation_config + ): + """Create a new import/export job on an IoT hub. + + Create a new import/export job on an IoT hub. See + https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-identity-registry#import-and-export-device-identities + for more information. + + :param job_properties: Specifies the job specification. + :type job_properties: ~protocol.models.JobProperties + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: JobProperties or ClientRawResponse if raw=true + :rtype: ~protocol.models.JobProperties or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.create_import_export_job.metadata["url"] + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + header_parameters["Content-Type"] = "application/json; charset=utf-8" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct body + body_content = self._serialize.body(job_properties, "JobProperties") + + # Construct and send request + request = self._client.post(url, query_parameters, header_parameters, body_content) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("JobProperties", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + create_import_export_job.metadata = {"url": "/jobs/create"} + + def get_import_export_jobs(self, custom_headers=None, raw=False, **operation_config): + """Gets the status of all import/export jobs in an iot hub. + + Gets the status of all import/export jobs in an iot hub. See + https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-identity-registry#import-and-export-device-identities + for more information. + + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: list or ClientRawResponse if raw=true + :rtype: list[~protocol.models.JobProperties] or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.get_import_export_jobs.metadata["url"] + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.get(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("[JobProperties]", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + get_import_export_jobs.metadata = {"url": "/jobs"} + + def get_import_export_job(self, id, custom_headers=None, raw=False, **operation_config): + """Gets the status of an import or export job in an iot hub. + + Gets the status of an import or export job in an iot hub. See + https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-identity-registry#import-and-export-device-identities + for more information. + + :param id: Job ID. + :type id: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: JobProperties or ClientRawResponse if raw=true + :rtype: ~protocol.models.JobProperties or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.get_import_export_job.metadata["url"] + path_format_arguments = {"id": self._serialize.url("id", id, "str")} + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.get(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("JobProperties", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + get_import_export_job.metadata = {"url": "/jobs/{id}"} + + def cancel_import_export_job(self, id, custom_headers=None, raw=False, **operation_config): + """Cancels an import or export job in an IoT hub. + + Cancels an import or export job in an IoT hub. See + https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-identity-registry#import-and-export-device-identities + for more information. + + :param id: Job ID. + :type id: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: object or ClientRawResponse if raw=true + :rtype: object or ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.cancel_import_export_job.metadata["url"] + path_format_arguments = {"id": self._serialize.url("id", id, "str")} + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.delete(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200, 204]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("object", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + cancel_import_export_job.metadata = {"url": "/jobs/{id}"} + + def get_job(self, id, custom_headers=None, raw=False, **operation_config): + """Retrieves details of a scheduled job from an IoT hub. + + Retrieves details of a scheduled job from an IoT hub. See + https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-jobs + for more information. + + :param id: Job ID. + :type id: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: JobResponse or ClientRawResponse if raw=true + :rtype: ~protocol.models.JobResponse or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.get_job.metadata["url"] + path_format_arguments = {"id": self._serialize.url("id", id, "str")} + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.get(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("JobResponse", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + get_job.metadata = {"url": "/jobs/v2/{id}"} + + def create_job(self, id, job_request, custom_headers=None, raw=False, **operation_config): + """Creates a new job to schedule update twins or device direct methods on + an IoT hub at a scheduled time. + + Creates a new job to schedule update twins or device direct methods on + an IoT hub at a scheduled time. See + https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-jobs + for more information. + + :param id: Job ID. + :type id: str + :param job_request: + :type job_request: ~protocol.models.JobRequest + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: JobResponse or ClientRawResponse if raw=true + :rtype: ~protocol.models.JobResponse or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.create_job.metadata["url"] + path_format_arguments = {"id": self._serialize.url("id", id, "str")} + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + header_parameters["Content-Type"] = "application/json; charset=utf-8" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct body + body_content = self._serialize.body(job_request, "JobRequest") + + # Construct and send request + request = self._client.put(url, query_parameters, header_parameters, body_content) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("JobResponse", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + create_job.metadata = {"url": "/jobs/v2/{id}"} + + def cancel_job(self, id, custom_headers=None, raw=False, **operation_config): + """Cancels a scheduled job on an IoT hub. + + Cancels a scheduled job on an IoT hub. See + https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-jobs + for more information. + + :param id: Job ID. + :type id: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: JobResponse or ClientRawResponse if raw=true + :rtype: ~protocol.models.JobResponse or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.cancel_job.metadata["url"] + path_format_arguments = {"id": self._serialize.url("id", id, "str")} + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.post(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("JobResponse", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + cancel_job.metadata = {"url": "/jobs/v2/{id}/cancel"} + + def query_jobs( + self, job_type=None, job_status=None, custom_headers=None, raw=False, **operation_config + ): + """Query an IoT hub to retrieve information regarding jobs using the IoT + Hub query language. + + Query an IoT hub to retrieve information regarding jobs using the IoT + Hub query language. See + https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-query-language + for more information. Pagination of results is supported. This returns + information about jobs only. + + :param job_type: Job Type. + :type job_type: str + :param job_status: Job Status. + :type job_status: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: QueryResult or ClientRawResponse if raw=true + :rtype: ~protocol.models.QueryResult or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.query_jobs.metadata["url"] + + # Construct parameters + query_parameters = {} + if job_type is not None: + query_parameters["jobType"] = self._serialize.query("job_type", job_type, "str") + if job_status is not None: + query_parameters["jobStatus"] = self._serialize.query("job_status", job_status, "str") + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.get(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("QueryResult", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + query_jobs.metadata = {"url": "/jobs/v2/query"} diff --git a/azure-iot-hub/azure/iot/hub/protocol/operations/registry_manager_operations.py b/azure-iot-hub/azure/iot/hub/protocol/operations/registry_manager_operations.py new file mode 100644 index 000000000..e2a76ecb5 --- /dev/null +++ b/azure-iot-hub/azure/iot/hub/protocol/operations/registry_manager_operations.py @@ -0,0 +1,829 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is +# regenerated. +# -------------------------------------------------------------------------- + +from msrest.pipeline import ClientRawResponse +from msrest.exceptions import HttpOperationError + +from .. import models + + +class RegistryManagerOperations(object): + """RegistryManagerOperations operations. + + :param client: Client for service requests. + :param config: Configuration of service client. + :param serializer: An object model serializer. + :param deserializer: An object model deserializer. + :ivar api_version: Version of the Api. Constant value: '2020-03-13'. + """ + + models = models + + def __init__(self, client, config, serializer, deserializer): + + self._client = client + self._serialize = serializer + self._deserialize = deserializer + + self.config = config + self.api_version = "2020-03-13" + + def get_device_statistics(self, custom_headers=None, raw=False, **operation_config): + """Retrieves statistics about device identities in the IoT hub’s identity + registry. + + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: RegistryStatistics or ClientRawResponse if raw=true + :rtype: ~protocol.models.RegistryStatistics or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.get_device_statistics.metadata["url"] + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.get(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("RegistryStatistics", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + get_device_statistics.metadata = {"url": "/statistics/devices"} + + def get_service_statistics(self, custom_headers=None, raw=False, **operation_config): + """Retrieves service statistics for this IoT hub’s identity registry. + + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: ServiceStatistics or ClientRawResponse if raw=true + :rtype: ~protocol.models.ServiceStatistics or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.get_service_statistics.metadata["url"] + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.get(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("ServiceStatistics", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + get_service_statistics.metadata = {"url": "/statistics/service"} + + def get_devices(self, top=None, custom_headers=None, raw=False, **operation_config): + """Get the identities of multiple devices from the IoT hub identity + registry. Not recommended. Use the IoT Hub query language to retrieve + device twin and device identity information. See + https://docs.microsoft.com/en-us/rest/api/iothub/service/queryiothub + and + https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-query-language + for more information. + + :param top: This parameter when specified, defines the maximum number + of device identities that are returned. Any value outside the range of + 1-1000 is considered to be 1000. + :type top: int + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: list or ClientRawResponse if raw=true + :rtype: list[~protocol.models.Device] or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.get_devices.metadata["url"] + + # Construct parameters + query_parameters = {} + if top is not None: + query_parameters["top"] = self._serialize.query("top", top, "int") + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.get(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("[Device]", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + get_devices.metadata = {"url": "/devices"} + + def bulk_device_crud(self, devices, custom_headers=None, raw=False, **operation_config): + """Create, update, or delete the identities of multiple devices from the + IoT hub identity registry. + + Create, update, or delete the identiies of multiple devices from the + IoT hub identity registry. A device identity can be specified only once + in the list. Different operations (create, update, delete) on different + devices are allowed. A maximum of 100 devices can be specified per + invocation. For large scale operations, consider using the import + feature using blob + storage(https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-identity-registry#import-and-export-device-identities). + + :param devices: + :type devices: list[~protocol.models.ExportImportDevice] + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: BulkRegistryOperationResult or ClientRawResponse if raw=true + :rtype: ~protocol.models.BulkRegistryOperationResult or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.bulk_device_crud.metadata["url"] + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + header_parameters["Content-Type"] = "application/json; charset=utf-8" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct body + body_content = self._serialize.body(devices, "[ExportImportDevice]") + + # Construct and send request + request = self._client.post(url, query_parameters, header_parameters, body_content) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200, 400]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("BulkRegistryOperationResult", response) + if response.status_code == 400: + deserialized = self._deserialize("BulkRegistryOperationResult", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + bulk_device_crud.metadata = {"url": "/devices"} + + def query_iot_hub( + self, + query_specification, + x_ms_continuation=None, + x_ms_max_item_count=None, + custom_headers=None, + raw=False, + **operation_config + ): + """Query an IoT hub to retrieve information regarding device twins using a + SQL-like language. + + Query an IoT hub to retrieve information regarding device twins using a + SQL-like language. See + https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-query-language + for more information. Pagination of results is supported. This returns + information about device twins only. + + :param query_specification: + :type query_specification: ~protocol.models.QuerySpecification + :param x_ms_continuation: + :type x_ms_continuation: str + :param x_ms_max_item_count: + :type x_ms_max_item_count: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: list or ClientRawResponse if raw=true + :rtype: list[~protocol.models.Twin] or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.query_iot_hub.metadata["url"] + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + header_parameters["Content-Type"] = "application/json; charset=utf-8" + if custom_headers: + header_parameters.update(custom_headers) + if x_ms_continuation is not None: + header_parameters["x-ms-continuation"] = self._serialize.header( + "x_ms_continuation", x_ms_continuation, "str" + ) + if x_ms_max_item_count is not None: + header_parameters["x-ms-max-item-count"] = self._serialize.header( + "x_ms_max_item_count", x_ms_max_item_count, "str" + ) + + # Construct body + body_content = self._serialize.body(query_specification, "QuerySpecification") + + # Construct and send request + request = self._client.post(url, query_parameters, header_parameters, body_content) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + header_dict = {} + + if response.status_code == 200: + deserialized = self._deserialize("[Twin]", response) + header_dict = {"x-ms-item-type": "str", "x-ms-continuation": "str"} + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + client_raw_response.add_headers(header_dict) + return client_raw_response + + return deserialized + + query_iot_hub.metadata = {"url": "/devices/query"} + + def get_device(self, id, custom_headers=None, raw=False, **operation_config): + """Retrieve a device from the identity registry of an IoT hub. + + Retrieve a device from the identity registry of an IoT hub. + + :param id: Device ID. + :type id: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: Device or ClientRawResponse if raw=true + :rtype: ~protocol.models.Device or ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.get_device.metadata["url"] + path_format_arguments = {"id": self._serialize.url("id", id, "str")} + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.get(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("Device", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + get_device.metadata = {"url": "/devices/{id}"} + + def create_or_update_device( + self, id, device, if_match=None, custom_headers=None, raw=False, **operation_config + ): + """Create or update the identity of a device in the identity registry of + an IoT hub. + + Create or update the identity of a device in the identity registry of + an IoT hub. An ETag must not be specified for the create operation. An + ETag must be specified for the update operation. Note that generationId + and deviceId cannot be updated by the user. + + :param id: Device ID. + :type id: str + :param device: + :type device: ~protocol.models.Device + :param if_match: + :type if_match: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: Device or ClientRawResponse if raw=true + :rtype: ~protocol.models.Device or ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.create_or_update_device.metadata["url"] + path_format_arguments = {"id": self._serialize.url("id", id, "str")} + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + header_parameters["Content-Type"] = "application/json; charset=utf-8" + if custom_headers: + header_parameters.update(custom_headers) + if if_match is not None: + header_parameters["If-Match"] = self._serialize.header("if_match", if_match, "str") + + # Construct body + body_content = self._serialize.body(device, "Device") + + # Construct and send request + request = self._client.put(url, query_parameters, header_parameters, body_content) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("Device", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + create_or_update_device.metadata = {"url": "/devices/{id}"} + + def delete_device(self, id, if_match=None, custom_headers=None, raw=False, **operation_config): + """Delete the identity of a device from the identity registry of an IoT + hub. + + Delete the identity of a device from the identity registry of an IoT + hub. This request requires the If-Match header. The client may specify + the ETag for the device identity on the request in order to compare to + the ETag maintained by the service for the purpose of optimistic + concurrency. The delete operation is performed only if the ETag sent by + the client matches the value maintained by the server, indicating that + the device identity has not been modified since it was retrieved by the + client. To force an unconditional delete, set If-Match to the wildcard + character (*). + + :param id: Device ID. + :type id: str + :param if_match: + :type if_match: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: None or ClientRawResponse if raw=true + :rtype: None or ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.delete_device.metadata["url"] + path_format_arguments = {"id": self._serialize.url("id", id, "str")} + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + if custom_headers: + header_parameters.update(custom_headers) + if if_match is not None: + header_parameters["If-Match"] = self._serialize.header("if_match", if_match, "str") + + # Construct and send request + request = self._client.delete(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [204]: + raise HttpOperationError(self._deserialize, response) + + if raw: + client_raw_response = ClientRawResponse(None, response) + return client_raw_response + + delete_device.metadata = {"url": "/devices/{id}"} + + def purge_command_queue(self, id, custom_headers=None, raw=False, **operation_config): + """Deletes all the pending commands for this device from the IoT hub. + + Deletes all the pending commands for this device from the IoT hub. + + :param id: Device ID. + :type id: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: PurgeMessageQueueResult or ClientRawResponse if raw=true + :rtype: ~protocol.models.PurgeMessageQueueResult or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.purge_command_queue.metadata["url"] + path_format_arguments = {"id": self._serialize.url("id", id, "str")} + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.delete(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("PurgeMessageQueueResult", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + purge_command_queue.metadata = {"url": "/devices/{id}/commands"} + + def get_modules_on_device(self, id, custom_headers=None, raw=False, **operation_config): + """Retrieve all the module identities on the device. + + :param id: Device ID. + :type id: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: list or ClientRawResponse if raw=true + :rtype: list[~protocol.models.Module] or + ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.get_modules_on_device.metadata["url"] + path_format_arguments = {"id": self._serialize.url("id", id, "str")} + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.get(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("[Module]", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + get_modules_on_device.metadata = {"url": "/devices/{id}/modules"} + + def get_module(self, id, mid, custom_headers=None, raw=False, **operation_config): + """Retrieve the specified module identity on the device. + + :param id: Device ID. + :type id: str + :param mid: Module ID. + :type mid: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: Module or ClientRawResponse if raw=true + :rtype: ~protocol.models.Module or ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.get_module.metadata["url"] + path_format_arguments = { + "id": self._serialize.url("id", id, "str"), + "mid": self._serialize.url("mid", mid, "str"), + } + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.get(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("Module", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + get_module.metadata = {"url": "/devices/{id}/modules/{mid}"} + + def create_or_update_module( + self, id, mid, module, if_match=None, custom_headers=None, raw=False, **operation_config + ): + """Create or update the module identity for device in IoT hub. An ETag + must not be specified for the create operation. An ETag must be + specified for the update operation. Note that moduleId and generation + cannot be updated by the user. + + :param id: Device ID. + :type id: str + :param mid: Module ID. + :type mid: str + :param module: + :type module: ~protocol.models.Module + :param if_match: + :type if_match: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: Module or ClientRawResponse if raw=true + :rtype: ~protocol.models.Module or ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.create_or_update_module.metadata["url"] + path_format_arguments = { + "id": self._serialize.url("id", id, "str"), + "mid": self._serialize.url("mid", mid, "str"), + } + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + header_parameters["Content-Type"] = "application/json; charset=utf-8" + if custom_headers: + header_parameters.update(custom_headers) + if if_match is not None: + header_parameters["If-Match"] = self._serialize.header("if_match", if_match, "str") + + # Construct body + body_content = self._serialize.body(module, "Module") + + # Construct and send request + request = self._client.put(url, query_parameters, header_parameters, body_content) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200, 201]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("Module", response) + if response.status_code == 201: + deserialized = self._deserialize("Module", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + create_or_update_module.metadata = {"url": "/devices/{id}/modules/{mid}"} + + def delete_module( + self, id, mid, if_match=None, custom_headers=None, raw=False, **operation_config + ): + """Delete the module identity for device of an IoT hub. This request + requires the If-Match header. The client may specify the ETag for the + device identity on the request in order to compare to the ETag + maintained by the service for the purpose of optimistic concurrency. + The delete operation is performed only if the ETag sent by the client + matches the value maintained by the server, indicating that the device + identity has not been modified since it was retrieved by the client. To + force an unconditional delete, set If-Match to the wildcard character + (*). + + :param id: Device ID. + :type id: str + :param mid: Module ID. + :type mid: str + :param if_match: + :type if_match: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: None or ClientRawResponse if raw=true + :rtype: None or ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.delete_module.metadata["url"] + path_format_arguments = { + "id": self._serialize.url("id", id, "str"), + "mid": self._serialize.url("mid", mid, "str"), + } + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + if custom_headers: + header_parameters.update(custom_headers) + if if_match is not None: + header_parameters["If-Match"] = self._serialize.header("if_match", if_match, "str") + + # Construct and send request + request = self._client.delete(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [204]: + raise HttpOperationError(self._deserialize, response) + + if raw: + client_raw_response = ClientRawResponse(None, response) + return client_raw_response + + delete_module.metadata = {"url": "/devices/{id}/modules/{mid}"} diff --git a/azure-iot-hub/azure/iot/hub/protocol/operations/twin_operations.py b/azure-iot-hub/azure/iot/hub/protocol/operations/twin_operations.py new file mode 100644 index 000000000..07d260e7a --- /dev/null +++ b/azure-iot-hub/azure/iot/hub/protocol/operations/twin_operations.py @@ -0,0 +1,458 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is +# regenerated. +# -------------------------------------------------------------------------- + +from msrest.pipeline import ClientRawResponse +from msrest.exceptions import HttpOperationError + +from .. import models + + +class TwinOperations(object): + """TwinOperations operations. + + :param client: Client for service requests. + :param config: Configuration of service client. + :param serializer: An object model serializer. + :param deserializer: An object model deserializer. + :ivar api_version: Version of the Api. Constant value: '2020-03-13'. + """ + + models = models + + def __init__(self, client, config, serializer, deserializer): + + self._client = client + self._serialize = serializer + self._deserialize = deserializer + + self.config = config + self.api_version = "2020-03-13" + + def get_device_twin(self, id, custom_headers=None, raw=False, **operation_config): + """Gets a device twin. + + Gets a device twin. See + https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-device-twins + for more information. + + :param id: Device ID. + :type id: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: Twin or ClientRawResponse if raw=true + :rtype: ~protocol.models.Twin or ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.get_device_twin.metadata["url"] + path_format_arguments = {"id": self._serialize.url("id", id, "str")} + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.get(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("Twin", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + get_device_twin.metadata = {"url": "/twins/{id}"} + + def replace_device_twin( + self, + id, + device_twin_info, + if_match=None, + custom_headers=None, + raw=False, + **operation_config + ): + """Replaces tags and desired properties of a device twin. + + Replaces a device twin. See + https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-device-twins + for more information. + + :param id: Device ID. + :type id: str + :param device_twin_info: Device twin info + :type device_twin_info: ~protocol.models.Twin + :param if_match: + :type if_match: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: Twin or ClientRawResponse if raw=true + :rtype: ~protocol.models.Twin or ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.replace_device_twin.metadata["url"] + path_format_arguments = {"id": self._serialize.url("id", id, "str")} + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + header_parameters["Content-Type"] = "application/json; charset=utf-8" + if custom_headers: + header_parameters.update(custom_headers) + if if_match is not None: + header_parameters["If-Match"] = self._serialize.header("if_match", if_match, "str") + + # Construct body + body_content = self._serialize.body(device_twin_info, "Twin") + + # Construct and send request + request = self._client.put(url, query_parameters, header_parameters, body_content) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("Twin", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + replace_device_twin.metadata = {"url": "/twins/{id}"} + + def update_device_twin( + self, + id, + device_twin_info, + if_match=None, + custom_headers=None, + raw=False, + **operation_config + ): + """Updates tags and desired properties of a device twin. + + Updates a device twin. See + https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-device-twins + for more information. + + :param id: Device ID. + :type id: str + :param device_twin_info: Device twin info + :type device_twin_info: ~protocol.models.Twin + :param if_match: + :type if_match: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: Twin or ClientRawResponse if raw=true + :rtype: ~protocol.models.Twin or ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.update_device_twin.metadata["url"] + path_format_arguments = {"id": self._serialize.url("id", id, "str")} + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + header_parameters["Content-Type"] = "application/json; charset=utf-8" + if custom_headers: + header_parameters.update(custom_headers) + if if_match is not None: + header_parameters["If-Match"] = self._serialize.header("if_match", if_match, "str") + + # Construct body + body_content = self._serialize.body(device_twin_info, "Twin") + + # Construct and send request + request = self._client.patch(url, query_parameters, header_parameters, body_content) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("Twin", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + update_device_twin.metadata = {"url": "/twins/{id}"} + + def get_module_twin(self, id, mid, custom_headers=None, raw=False, **operation_config): + """Gets a module twin. + + Gets a module twin. See + https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-device-twins + for more information. + + :param id: Device ID. + :type id: str + :param mid: Module ID. + :type mid: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: Twin or ClientRawResponse if raw=true + :rtype: ~protocol.models.Twin or ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.get_module_twin.metadata["url"] + path_format_arguments = { + "id": self._serialize.url("id", id, "str"), + "mid": self._serialize.url("mid", mid, "str"), + } + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + if custom_headers: + header_parameters.update(custom_headers) + + # Construct and send request + request = self._client.get(url, query_parameters, header_parameters) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("Twin", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + get_module_twin.metadata = {"url": "/twins/{id}/modules/{mid}"} + + def replace_module_twin( + self, + id, + mid, + device_twin_info, + if_match=None, + custom_headers=None, + raw=False, + **operation_config + ): + """Replaces tags and desired properties of a module twin. + + Replaces a module twin. See + https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-device-twins + for more information. + + :param id: Device ID. + :type id: str + :param mid: Module ID. + :type mid: str + :param device_twin_info: Device twin info + :type device_twin_info: ~protocol.models.Twin + :param if_match: + :type if_match: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: Twin or ClientRawResponse if raw=true + :rtype: ~protocol.models.Twin or ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.replace_module_twin.metadata["url"] + path_format_arguments = { + "id": self._serialize.url("id", id, "str"), + "mid": self._serialize.url("mid", mid, "str"), + } + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + header_parameters["Content-Type"] = "application/json; charset=utf-8" + if custom_headers: + header_parameters.update(custom_headers) + if if_match is not None: + header_parameters["If-Match"] = self._serialize.header("if_match", if_match, "str") + + # Construct body + body_content = self._serialize.body(device_twin_info, "Twin") + + # Construct and send request + request = self._client.put(url, query_parameters, header_parameters, body_content) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("Twin", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + replace_module_twin.metadata = {"url": "/twins/{id}/modules/{mid}"} + + def update_module_twin( + self, + id, + mid, + device_twin_info, + if_match=None, + custom_headers=None, + raw=False, + **operation_config + ): + """Updates tags and desired properties of a module twin. + + Updates a module twin. See + https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-device-twins + for more information. + + :param id: Device ID. + :type id: str + :param mid: Module ID. + :type mid: str + :param device_twin_info: Device twin information + :type device_twin_info: ~protocol.models.Twin + :param if_match: + :type if_match: str + :param dict custom_headers: headers that will be added to the request + :param bool raw: returns the direct response alongside the + deserialized response + :param operation_config: :ref:`Operation configuration + overrides`. + :return: Twin or ClientRawResponse if raw=true + :rtype: ~protocol.models.Twin or ~msrest.pipeline.ClientRawResponse + :raises: + :class:`HttpOperationError` + """ + # Construct URL + url = self.update_module_twin.metadata["url"] + path_format_arguments = { + "id": self._serialize.url("id", id, "str"), + "mid": self._serialize.url("mid", mid, "str"), + } + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters["api-version"] = self._serialize.query( + "self.api_version", self.api_version, "str" + ) + + # Construct headers + header_parameters = {} + header_parameters["Accept"] = "application/json" + header_parameters["Content-Type"] = "application/json; charset=utf-8" + if custom_headers: + header_parameters.update(custom_headers) + if if_match is not None: + header_parameters["If-Match"] = self._serialize.header("if_match", if_match, "str") + + # Construct body + body_content = self._serialize.body(device_twin_info, "Twin") + + # Construct and send request + request = self._client.patch(url, query_parameters, header_parameters, body_content) + response = self._client.send(request, stream=False, **operation_config) + + if response.status_code not in [200]: + raise HttpOperationError(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize("Twin", response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + + update_module_twin.metadata = {"url": "/twins/{id}/modules/{mid}"} diff --git a/azure-iot-hub/azure/iot/hub/protocol/version.py b/azure-iot-hub/azure/iot/hub/protocol/version.py index 641d943e1..89112d98b 100644 --- a/azure-iot-hub/azure/iot/hub/protocol/version.py +++ b/azure-iot-hub/azure/iot/hub/protocol/version.py @@ -5,4 +5,4 @@ # regenerated. # -------------------------------------------------------------------------- -VERSION = "2019-07-01-preview" +VERSION = "2020-03-13" diff --git a/azure-iot-hub/samples/iothub_configuration_manager_sample.py b/azure-iot-hub/samples/iothub_configuration_manager_sample.py new file mode 100644 index 000000000..93ee4cb14 --- /dev/null +++ b/azure-iot-hub/samples/iothub_configuration_manager_sample.py @@ -0,0 +1,75 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import sys +import os +from azure.iot.hub import IoTHubConfigurationManager +from azure.iot.hub.models import Configuration, ConfigurationContent, ConfigurationMetrics + + +iothub_connection_str = os.getenv("IOTHUB_CONNECTION_STRING") + + +def print_configuration(title, config): + print() + print(title) + print("Configuration:") + print(" {}".format(config)) + print("Configuration - content:") + print(" {}".format(config.content)) + print("Configuration - metrics:") + print(" {}".format(config.metrics)) + + +def create_configuration(config_id): + config = Configuration() + config.id = config_id + + content = ConfigurationContent( + device_content={"properties.desired.chiller-water": {"temperature: 68, pressure:28"}} + ) + config.content = content + + metrics = ConfigurationMetrics( + queries={ + "waterSettingPending": "SELECT deviceId FROM devices WHERE properties.reported.chillerWaterSettings.status='pending'" + } + ) + config.metrics = metrics + + return config + + +try: + # Create IoTHubConfigurationManager + iothub_configuration = IoTHubConfigurationManager(iothub_connection_str) + + # Create configuration + config_id = "sample_config" + sample_configuration = create_configuration(config_id) + print_configuration("Sample configuration", sample_configuration) + + created_config = iothub_configuration.create_configuration(sample_configuration) + print_configuration("Created configuration", created_config) + + # Get configuration + get_config = iothub_configuration.get_configuration(config_id) + print_configuration("Get configuration", get_config) + + # Delete configuration + iothub_configuration.delete_configuration(config_id) + + # Get all configurations + configurations = iothub_configuration.get_configurations() + if configurations: + print_configuration("Get all configurations", configurations[0]) + else: + print("No configuration found") + +except Exception as ex: + print("Unexpected error {0}".format(ex)) +except KeyboardInterrupt: + print("iothub_registry_manager_sample stopped") diff --git a/azure-iot-hub/samples/iothub_job_manager_sample.py b/azure-iot-hub/samples/iothub_job_manager_sample.py new file mode 100644 index 000000000..e684b2582 --- /dev/null +++ b/azure-iot-hub/samples/iothub_job_manager_sample.py @@ -0,0 +1,119 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import sys +import os +from azure.iot.hub import IoTHubJobManager +from azure.iot.hub.models import JobProperties, JobRequest + + +iothub_connection_str = os.getenv("IOTHUB_CONNECTION_STRING") +output_container_uri = os.getenv("JOB_EXPORT_IMPORT_OUTPUT_URI") + + +def create_export_import_job_properties(): + job_properties = JobProperties() + job_properties.authentication_type = "keyBased" + job_properties.type = "export" + job_properties.output_blob_container_uri = output_container_uri + return job_properties + + +def print_export_import_job(title, job): + print() + print(title) + print(" job_id: {}".format(job.job_id)) + print(" type: {}".format(job.type)) + print(" status: {}".format(job.status)) + print(" start_time_utc: {}".format(job.start_time_utc)) + print(" end_time_utc: {}".format(job.end_time_utc)) + print(" progress: {}".format(job.progress)) + print(" input_blob_container_uri: {}".format(job.input_blob_container_uri)) + print(" input_blob_name: {}".format(job.input_blob_name)) + print(" output_blob_container_uri: {}".format(job.output_blob_container_uri)) + print(" output_blob_name: {}".format(job.output_blob_name)) + print(" exclude_keys_in_export: {}".format(job.exclude_keys_in_export)) + print(" storage_authentication_type: {}".format(job.storage_authentication_type)) + print(" failure_reason: {}".format(job.failure_reason)) + + +def print_export_import_jobs(title, export_import_jobs): + print("") + x = 1 + if len([export_import_jobs]) > 0: + for j in range(len(export_import_jobs)): + print_export_import_job("{0}: {1}".format(title, x), export_import_jobs[j]) + x += 1 + else: + print("No item found") + + +def create_job_request(): + job = JobRequest() + job.job_id = "sample_cloud_to_device_method" + job.type = "cloudToDeviceMethod" + job.start_time = "" + job.max_execution_time_in_seconds = 60 + job.update_twin = "" + job.query_condition = "" + return job + + +def print_job_response(title, job): + print() + print(title) + print(" job_id: {}".format(job.job_id)) + print(" type: {}".format(job.type)) + print(" start_time: {}".format(job.start_time)) + print(" max_execution_time_in_seconds: {}".format(job.max_execution_time_in_seconds)) + print(" update_twin: {}".format(job.update_twin)) + print(" query_condition: {}".format(job.query_condition)) + + +try: + # Create IoTHubJobManager + iothub_job_manager = IoTHubJobManager(iothub_connection_str) + + # Get all export/import jobs + export_import_jobs = iothub_job_manager.get_import_export_jobs() + if export_import_jobs: + print_export_import_jobs("Get all export/import jobs", export_import_jobs) + else: + print("No export/import job found") + + # Create export/import job + new_export_import_job = iothub_job_manager.create_import_export_job( + create_export_import_job_properties() + ) + print_export_import_job("Create export/import job result: ", new_export_import_job) + + # Get export/import job + get_export_import_job = iothub_job_manager.get_import_export_job(new_export_import_job.job_id) + print_export_import_job("Get export/import job result: ", get_export_import_job) + + # Cancel export/import job + cancel_export_import_job = iothub_job_manager.cancel_import_export_job( + get_export_import_job.job_id + ) + print(cancel_export_import_job) + + # Create job + job_request = create_job_request() + new_job_response = iothub_job_manager.create_job(job_request.job_id, job_request) + print_job_response("Create job response: ", new_job_response) + + # Get job + get_job_response = iothub_job_manager.get_job(new_job_response.job_id) + print_job_response("Get job response: ", get_job_response) + + # Cancel job + cancel_job_response = iothub_job_manager.cancel_job(get_job_response.job_id) + print_job_response("Cancel job response: ", cancel_job_response) + +except Exception as ex: + print("Unexpected error {0}".format(ex)) +except KeyboardInterrupt: + print("iothub_registry_manager_sample stopped") diff --git a/azure-iot-hub/samples/iothub_registry_manager_bulk_create_sample.py b/azure-iot-hub/samples/iothub_registry_manager_bulk_create_sample.py new file mode 100644 index 000000000..7f16bb305 --- /dev/null +++ b/azure-iot-hub/samples/iothub_registry_manager_bulk_create_sample.py @@ -0,0 +1,84 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import sys +import os +from azure.iot.hub import IoTHubRegistryManager +from azure.iot.hub.models import ExportImportDevice, AuthenticationMechanism, SymmetricKey + +iothub_connection_str = os.getenv("IOTHUB_CONNECTION_STRING") + + +def print_device_info(title, iothub_device): + print(title + ":") + print("device_id = {0}".format(iothub_device.device_id)) + print("authentication.type = {0}".format(iothub_device.authentication.type)) + print("authentication.symmetric_key = {0}".format(iothub_device.authentication.symmetric_key)) + print( + "authentication.x509_thumbprint = {0}".format(iothub_device.authentication.x509_thumbprint) + ) + print("connection_state = {0}".format(iothub_device.connection_state)) + print( + "connection_state_updated_tTime = {0}".format(iothub_device.connection_state_updated_time) + ) + print( + "cloud_to_device_message_count = {0}".format(iothub_device.cloud_to_device_message_count) + ) + print("device_scope = {0}".format(iothub_device.device_scope)) + print("etag = {0}".format(iothub_device.etag)) + print("generation_id = {0}".format(iothub_device.generation_id)) + print("last_activity_time = {0}".format(iothub_device.last_activity_time)) + print("status = {0}".format(iothub_device.status)) + print("status_reason = {0}".format(iothub_device.status_reason)) + print("status_updated_time = {0}".format(iothub_device.status_updated_time)) + print("") + + +try: + # Create IoTHubRegistryManager + iothub_registry_manager = IoTHubRegistryManager(iothub_connection_str) + + primary_key1 = "aaabbbcccdddeeefffggghhhiiijjjkkklllmmmnnnoo" + secondary_key1 = "111222333444555666777888999000aaabbbcccdddee" + symmetric_key1 = SymmetricKey(primary_key=primary_key1, secondary_key=secondary_key1) + authentication1 = AuthenticationMechanism(type="sas", symmetric_key=symmetric_key1) + device1 = ExportImportDevice(id="BulkDevice1", status="enabled", authentication=authentication1) + + primary_key2 = "cccbbbaaadddeeefffggghhhiiijjjkkklllmmmnnnoo" + secondary_key2 = "333222111444555666777888999000aaabbbcccdddee" + symmetric_key2 = SymmetricKey(primary_key=primary_key2, secondary_key=secondary_key2) + authentication2 = AuthenticationMechanism(type="sas", symmetric_key=symmetric_key2) + device2 = ExportImportDevice(id="BulkDevice2", status="enabled", authentication=authentication2) + + # Create devices + device1.import_mode = "create" + device2.import_mode = "create" + device_list = [device1, device2] + + iothub_registry_manager.bulk_create_or_update_devices(device_list) + + # Get devices (max. 1000 with get_devices API) + max_number_of_devices = 10 + devices = iothub_registry_manager.get_devices(max_number_of_devices) + if devices: + x = 0 + for d in devices: + print_device_info("Get devices {0}".format(x), d) + x += 1 + else: + print("No device found") + + # Delete devices + device1.import_mode = "delete" + device2.import_mode = "delete" + device_list = [device1, device2] + + iothub_registry_manager.bulk_create_or_update_devices(device_list) + +except Exception as ex: + print("Unexpected error {0}".format(ex)) +except KeyboardInterrupt: + print("iothub_registry_manager_sample stopped") diff --git a/azure-iot-hub/samples/iothub_registry_manager_c2d_sample.py b/azure-iot-hub/samples/iothub_registry_manager_c2d_sample.py new file mode 100644 index 000000000..0cb20ee7c --- /dev/null +++ b/azure-iot-hub/samples/iothub_registry_manager_c2d_sample.py @@ -0,0 +1,26 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import sys +import os +from azure.iot.hub import IoTHubRegistryManager + +connection_str = os.getenv("IOTHUB_CONNECTION_STRING") +device_id = os.getenv("IOTHUB_DEVICE_ID") +send_message = "C2D message to be send to device" + +try: + # Create IoTHubRegistryManager + registry_manager = IoTHubRegistryManager(connection_str) + print("Conn String: {0}".format(connection_str)) + + # Send Message To Device + registry_manager.send_c2d_message(device_id, send_message) + +except Exception as ex: + print("Unexpected error {0}".format(ex)) +except KeyboardInterrupt: + print("iothub_statistics stopped") diff --git a/azure-iot-hub/samples/iothub_registry_manager_method_sample.py b/azure-iot-hub/samples/iothub_registry_manager_method_sample.py new file mode 100644 index 000000000..87609a70e --- /dev/null +++ b/azure-iot-hub/samples/iothub_registry_manager_method_sample.py @@ -0,0 +1,29 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import sys +import os +from azure.iot.hub import IoTHubRegistryManager +from azure.iot.hub.models import CloudToDeviceMethod + + +iothub_connection_str = os.getenv("IOTHUB_CONNECTION_STRING") +device_id = os.getenv("IOTHUB_DEVICE_ID") +method_name = "lockDoor" +method_payload = "now" + + +try: + # Create IoTHubRegistryManager + registry_manager = IoTHubRegistryManager(iothub_connection_str) + + deviceMethod = CloudToDeviceMethod(method_name=method_name, payload=method_payload) + registry_manager.invoke_device_method(device_id, deviceMethod) + +except Exception as ex: + print("Unexpected error {0}".format(ex)) +except KeyboardInterrupt: + print("iothub_registry_manager_sample stopped") diff --git a/azure-iot-hub/samples/iothub_registry_manager_module_method_sample.py b/azure-iot-hub/samples/iothub_registry_manager_module_method_sample.py new file mode 100644 index 000000000..5627e6efe --- /dev/null +++ b/azure-iot-hub/samples/iothub_registry_manager_module_method_sample.py @@ -0,0 +1,40 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import sys +import os +from azure.iot.hub import IoTHubRegistryManager +from azure.iot.hub.models import CloudToDeviceMethod + +iothub_connection_str = os.getenv("IOTHUB_CONNECTION_STRING") +device_id = os.getenv("IOTHUB_DEVICE_ID") +module_id = os.getenv("IOTHUB_MODULE_ID") +method_name = "lockDoor" +method_payload = "now" + +try: + # RegistryManager + iothub_registry_manager = IoTHubRegistryManager(iothub_connection_str) + + # Create Module + primary_key = "aaabbbcccdddeeefffggghhhiiijjjkkklllmmmnnnoo" + secondary_key = "111222333444555666777888999000aaabbbcccdddee" + managed_by = "" + new_module = iothub_registry_manager.create_module_with_sas( + device_id, module_id, managed_by, primary_key, secondary_key + ) + + deviceMethod = CloudToDeviceMethod(method_name=method_name, payload=method_payload) + iothub_registry_manager.invoke_device_module_method(device_id, module_id, deviceMethod) + + # Delete Module + iothub_registry_manager.delete_module(device_id, module_id) + print("Deleted Module {0}".format(module_id)) + +except Exception as ex: + print("Unexpected error {0}".format(ex)) +except KeyboardInterrupt: + print("IoTHubRegistryManager sample stopped") diff --git a/azure-iot-hub/samples/iothub_registry_manager_module_sample.py b/azure-iot-hub/samples/iothub_registry_manager_module_sample.py new file mode 100644 index 000000000..fcf2a42ca --- /dev/null +++ b/azure-iot-hub/samples/iothub_registry_manager_module_sample.py @@ -0,0 +1,108 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import sys +import os +from azure.iot.hub import IoTHubRegistryManager +from azure.iot.hub.models import Twin, TwinProperties + +iothub_connection_str = os.getenv("IOTHUB_CONNECTION_STRING") +device_id = os.getenv("IOTHUB_DEVICE_ID") +module_id = os.getenv("IOTHUB_MODULE_ID") + + +def print_module_info(title, iothub_module): + print(title + ":") + print("iothubModule.device_id = {0}".format(iothub_module.device_id)) + print("iothubModule.module_id = {0}".format(iothub_module.module_id)) + print("iothubModule.managed_by = {0}".format(iothub_module.managed_by)) + print("iothubModule.generation_id = {0}".format(iothub_module.generation_id)) + print("iothubModule.etag = {0}".format(iothub_module.etag)) + print( + "iothubModule.connection_state = {0}".format(iothub_module.connection_state) + ) + print( + "iothubModule.connection_state_updated_time = {0}".format( + iothub_module.connection_state_updated_time + ) + ) + print( + "iothubModule.last_activity_time = {0}".format(iothub_module.last_activity_time) + ) + print( + "iothubModule.cloud_to_device_message_count = {0}".format( + iothub_module.cloud_to_device_message_count + ) + ) + print("iothubModule.authentication = {0}".format(iothub_module.authentication)) + print("") + + +try: + # RegistryManager + iothub_registry_manager = IoTHubRegistryManager(iothub_connection_str) + + # Create Module + primary_key = "aaabbbcccdddeeefffggghhhiiijjjkkklllmmmnnnoo" + secondary_key = "111222333444555666777888999000aaabbbcccdddee" + managed_by = "" + new_module = iothub_registry_manager.create_module_with_sas( + device_id, module_id, managed_by, primary_key, secondary_key + ) + print_module_info("Create Module", new_module) + + # Get Module + iothub_module = iothub_registry_manager.get_module(device_id, module_id) + print_module_info("Get Module", iothub_module) + + # Update Module + primary_key = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + secondary_key = "yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy" + managed_by = "testManagedBy" + updated_module = iothub_registry_manager.update_module_with_sas( + device_id, module_id, managed_by, iothub_module.etag, primary_key, secondary_key + ) + print_module_info("Update Module", updated_module) + + # Get Module Twin + module_twin = iothub_registry_manager.get_module_twin(device_id, module_id) + print(module_twin) + + # # Replace Twin + new_twin = Twin() + new_twin = module_twin + new_twin.properties = TwinProperties(desired={"telemetryInterval": 9000}) + print(new_twin) + print("") + + replaced_module_twin = iothub_registry_manager.replace_module_twin( + device_id, module_id, new_twin + ) + print(replaced_module_twin) + print("") + + # Update twin + twin_patch = Twin() + twin_patch.properties = TwinProperties(desired={"telemetryInterval": 3000}) + updated_module_twin = iothub_registry_manager.update_module_twin( + device_id, module_id, twin_patch, module_twin.etag + ) + print(updated_module_twin) + print("") + + # Get all modules on the device + all_modules = iothub_registry_manager.get_modules(device_id) + for module in all_modules: + print_module_info("", module) + + # Delete Module + iothub_registry_manager.delete_module(device_id, module_id) + print("Deleted Module {0}".format(module_id)) + +except Exception as ex: + print("Unexpected error {0}".format(ex)) +except KeyboardInterrupt: + print("IoTHubRegistryManager sample stopped") diff --git a/azure-iot-hub/samples/iothub_registry_manager_query_sample.py b/azure-iot-hub/samples/iothub_registry_manager_query_sample.py new file mode 100644 index 000000000..3d2aedd06 --- /dev/null +++ b/azure-iot-hub/samples/iothub_registry_manager_query_sample.py @@ -0,0 +1,80 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import sys +import os +from azure.iot.hub import IoTHubRegistryManager +from azure.iot.hub.models import QuerySpecification + +iothub_connection_str = os.getenv("IOTHUB_CONNECTION_STRING") + + +def print_twin(title, iothub_device): + print(title + ":") + print("device_id = {0}".format(iothub_device.device_id)) + print("module_id = {0}".format(iothub_device.module_id)) + print("authentication_type = {0}".format(iothub_device.authentication_type)) + print("x509_thumbprint = {0}".format(iothub_device.x509_thumbprint)) + print("etag = {0}".format(iothub_device.etag)) + print("device_etag = {0}".format(iothub_device.device_etag)) + print("tags = {0}".format(iothub_device.tags)) + print("version = {0}".format(iothub_device.version)) + + print("status = {0}".format(iothub_device.status)) + print("status_reason = {0}".format(iothub_device.status_reason)) + print("status_update_time = {0}".format(iothub_device.status_update_time)) + print("connection_state = {0}".format(iothub_device.connection_state)) + print("last_activity_time = {0}".format(iothub_device.last_activity_time)) + print( + "cloud_to_device_message_count = {0}".format(iothub_device.cloud_to_device_message_count) + ) + print("device_scope = {0}".format(iothub_device.device_scope)) + + print("properties = {0}".format(iothub_device.properties)) + print("additional_properties = {0}".format(iothub_device.additional_properties)) + print("") + + +def print_query_result(title, query_result): + print("") + print("Type: {0}".format(query_result.type)) + print("Continuation token: {0}".format(query_result.continuation_token)) + if query_result.items: + x = 1 + for d in query_result.items: + print_twin("{0}: {1}".format(title, x), d) + x += 1 + else: + print("No item found") + + +try: + # Create IoTHubRegistryManager + iothub_registry_manager = IoTHubRegistryManager(iothub_connection_str) + + query_specification = QuerySpecification(query="SELECT * FROM devices") + + # Get specified number of devices (in this case 4) + query_result0 = iothub_registry_manager.query_iot_hub(query_specification, None, 4) + print_query_result("Query 4 device twins", query_result0) + + # Get all device twins using query + query_result1 = iothub_registry_manager.query_iot_hub(query_specification) + print_query_result("Query all device twins", query_result1) + + # Paging... Get more devices (over 1000) + continuation_token = query_result1.continuation_token + if continuation_token: + query_result2 = iothub_registry_manager.query_iot_hub( + query_specification, continuation_token + ) + print_query_result("Query all device twins - continued", query_result2) + + +except Exception as ex: + print("Unexpected error {0}".format(ex)) +except KeyboardInterrupt: + print("iothub_registry_manager_sample stopped") diff --git a/azure-iot-hub/samples/iothub_registry_manager_sample.py b/azure-iot-hub/samples/iothub_registry_manager_sample.py index 628199d70..1571dacb1 100644 --- a/azure-iot-hub/samples/iothub_registry_manager_sample.py +++ b/azure-iot-hub/samples/iothub_registry_manager_sample.py @@ -7,9 +7,10 @@ import sys import os from azure.iot.hub import IoTHubRegistryManager +from azure.iot.hub.models import Twin, TwinProperties -connection_str = os.getenv("IOTHUB_CONNECTION_STRING") -device_id = "test_device" +iothub_connection_str = os.getenv("IOTHUB_CONNECTION_STRING") +device_id = os.getenv("IOTHUB_NEW_DEVICE_ID") def print_device_info(title, iothub_device): @@ -40,46 +41,80 @@ def print_device_info(title, iothub_device): # This sample creates and uses device with SAS authentication # For other authentication types use the appropriate create and update APIs: # X509: -# new_device = registry_manager.create_device_with_x509(device_id, primary_thumbprint, secondary_thumbprint, status) -# device_updated = registry_manager.update_device_with_X509(device_id, etag, primary_thumbprint, secondary_thumbprint, status) +# new_device = iothub_registry_manager.create_device_with_x509(device_id, primary_thumbprint, secondary_thumbprint, status) +# device_updated = iothub_registry_manager.update_device_with_X509(device_id, etag, primary_thumbprint, secondary_thumbprint, status) # Certificate authority: -# new_device = registry_manager.create_device_with_certificate_authority(device_id, status) -# device_updated = registry_manager.update_device_with_certificate_authority(self, device_id, etag, status): +# new_device = iothub_registry_manager.create_device_with_certificate_authority(device_id, status) +# device_updated = iothub_registry_manager.update_device_with_certificate_authority(self, device_id, etag, status): try: # Create IoTHubRegistryManager - registry_manager = IoTHubRegistryManager(connection_str) + iothub_registry_manager = IoTHubRegistryManager(iothub_connection_str) # Create a device primary_key = "aaabbbcccdddeeefffggghhhiiijjjkkklllmmmnnnoo" secondary_key = "111222333444555666777888999000aaabbbcccdddee" device_state = "enabled" - new_device = registry_manager.create_device_with_sas( + new_device = iothub_registry_manager.create_device_with_sas( device_id, primary_key, secondary_key, device_state ) print_device_info("create_device", new_device) # Get device information - device = registry_manager.get_device(device_id) + device = iothub_registry_manager.get_device(device_id) print_device_info("get_device", device) # Update device information primary_key = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" secondary_key = "yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy" device_state = "disabled" - device_updated = registry_manager.update_device_with_sas( - new_device.device_id, new_device.etag, primary_key, secondary_key, device_state + device_updated = iothub_registry_manager.update_device_with_sas( + device_id, device.etag, primary_key, secondary_key, device_state ) print_device_info("update_device", device_updated) + # Get device twin + twin = iothub_registry_manager.get_twin(device_id) + print(twin) + print("") + + # # Replace twin + new_twin = Twin() + new_twin = twin + new_twin.properties = TwinProperties(desired={"telemetryInterval": 9000}) + print(new_twin) + print("") + + replaced_twin = iothub_registry_manager.replace_twin(device_id, new_twin) + print(replaced_twin) + print("") + + # Update twin + twin_patch = Twin() + twin_patch.properties = TwinProperties(desired={"telemetryInterval": 3000}) + updated_twin = iothub_registry_manager.update_twin(device_id, twin_patch, twin.etag) + print(updated_twin) + print("") + + # Get devices + max_number_of_devices = 10 + devices = iothub_registry_manager.get_devices(max_number_of_devices) + if devices: + x = 0 + for d in devices: + print_device_info("Get devices {0}".format(x), d) + x += 1 + else: + print("No device found") + # Delete the device - registry_manager.delete_device(device_id) + iothub_registry_manager.delete_device(device_id) print("GetServiceStatistics") - registry_statistics = registry_manager.get_service_statistics() + registry_statistics = iothub_registry_manager.get_service_statistics() print(registry_statistics) print("GetDeviceRegistryStatistics") - registry_statistics = registry_manager.get_device_registry_statistics() + registry_statistics = iothub_registry_manager.get_device_registry_statistics() print(registry_statistics) except Exception as ex: diff --git a/azure-iot-hub/samples/iothub_registry_manager_statistics_sample.py b/azure-iot-hub/samples/iothub_registry_manager_statistics_sample.py new file mode 100644 index 000000000..33c09b55d --- /dev/null +++ b/azure-iot-hub/samples/iothub_registry_manager_statistics_sample.py @@ -0,0 +1,51 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import sys +import os +from azure.iot.hub import IoTHubRegistryManager + +connection_str = os.getenv("IOTHUB_CONNECTION_STRING") + +try: + # Create IoTHubRegistryManager + registry_manager = IoTHubRegistryManager(connection_str) + + print("Conn String: {0}".format(connection_str)) + + # GetStatistics + service_statistics = registry_manager.get_service_statistics() + print("Service Statistics:") + print( + "Total connected device count : {0}".format( + service_statistics.connected_device_count + ) + ) + print("") + + registry_statistics = registry_manager.get_device_registry_statistics() + print("Device Registry Statistics:") + print( + "Total device count : {0}".format( + registry_statistics.total_device_count + ) + ) + print( + "Enabled device count : {0}".format( + registry_statistics.enabled_device_count + ) + ) + print( + "Disabled device count : {0}".format( + registry_statistics.disabled_device_count + ) + ) + print("") + +except Exception as ex: + print("Unexpected error {0}".format(ex)) +except KeyboardInterrupt: + print("iothub_statistics stopped") diff --git a/azure-iot-hub/service.json b/azure-iot-hub/service.json index 6ca2e7152..70aa6a0b8 100644 --- a/azure-iot-hub/service.json +++ b/azure-iot-hub/service.json @@ -1,8 +1,8 @@ { "swagger": "2.0", "info": { - "version": "2019-07-01-preview", - "title": "IotHub Gateway Service APIs - 2019-07-01-preview" + "version": "2020-03-13", + "title": "IotHub Gateway Service APIs" }, "host": "fully-qualified-iothubname.azure-devices.net", "schemes": [ @@ -12,7 +12,7 @@ "/configurations/{id}": { "get": { "summary": "Retrieve a configuration for Iot Hub devices and modules by it identifier.", - "operationId": "Service_GetConfiguration", + "operationId": "Configuration_Get", "consumes": [], "produces": [ "application/json" @@ -39,7 +39,7 @@ }, "put": { "summary": "Create or update the configuration for devices or modules of an IoT hub. An ETag must not be specified for the create operation. An ETag must be specified for the update operation. Note that configuration Id and Content cannot be updated by the user.", - "operationId": "Service_CreateOrUpdateConfiguration", + "operationId": "Configuration_CreateOrUpdate", "consumes": [ "application/json" ], @@ -88,7 +88,7 @@ }, "delete": { "summary": "Delete the configuration for devices or modules of an IoT hub. This request requires the If-Match header. The client may specify the ETag for the device identity on the request in order to compare to the ETag maintained by the service for the purpose of optimistic concurrency. The delete operation is performed only if the ETag sent by the client matches the value maintained by the server, indicating that the device identity has not been modified since it was retrieved by the client. To force an unconditional delete, set If-Match to the wildcard character (*).", - "operationId": "Service_DeleteConfiguration", + "operationId": "Configuration_Delete", "consumes": [], "produces": [ "application/json" @@ -120,7 +120,7 @@ "/configurations": { "get": { "summary": "Get multiple configurations for devices or modules of an IoT Hub. Returns the specified number of configurations for Iot Hub. Pagination is not supported.", - "operationId": "Service_GetConfigurations", + "operationId": "Configuration_GetConfigurations", "consumes": [], "produces": [ "application/json" @@ -154,7 +154,7 @@ "post": { "summary": "Validates the target condition query and custom metric queries for a configuration.", "description": "Validates the target condition query and custom metric queries for a configuration.", - "operationId": "Service_TestConfigurationQueries", + "operationId": "Configuration_TestQueries", "consumes": [ "application/json" ], @@ -187,7 +187,7 @@ "/statistics/devices": { "get": { "summary": "Retrieves statistics about device identities in the IoT hub’s identity registry.", - "operationId": "Service_GetDeviceRegistryStatistics", + "operationId": "RegistryManager_GetDeviceStatistics", "consumes": [], "produces": [ "application/json" @@ -210,7 +210,7 @@ "/statistics/service": { "get": { "summary": "Retrieves service statistics for this IoT hub’s identity registry.", - "operationId": "Service_GetServiceStatistics", + "operationId": "RegistryManager_GetServiceStatistics", "consumes": [], "produces": [ "application/json" @@ -233,7 +233,7 @@ "/devices": { "get": { "summary": "Get the identities of multiple devices from the IoT hub identity registry. Not recommended. Use the IoT Hub query language to retrieve device twin and device identity information. See https://docs.microsoft.com/en-us/rest/api/iothub/service/queryiothub and https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-query-language for more information.", - "operationId": "Service_GetDevices", + "operationId": "RegistryManager_GetDevices", "consumes": [], "produces": [ "application/json" @@ -266,7 +266,7 @@ "post": { "summary": "Create, update, or delete the identities of multiple devices from the IoT hub identity registry.", "description": "Create, update, or delete the identiies of multiple devices from the IoT hub identity registry. A device identity can be specified only once in the list. Different operations (create, update, delete) on different devices are allowed. A maximum of 100 devices can be specified per invocation. For large scale operations, consider using the import feature using blob storage(https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-identity-registry#import-and-export-device-identities).", - "operationId": "Service_BulkCreateOrUpdateDevices", + "operationId": "RegistryManager_BulkDeviceCRUD", "consumes": [ "application/json" ], @@ -309,7 +309,7 @@ "post": { "summary": "Query an IoT hub to retrieve information regarding device twins using a SQL-like language.", "description": "Query an IoT hub to retrieve information regarding device twins using a SQL-like language. See https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-query-language for more information. Pagination of results is supported. This returns information about device twins only.", - "operationId": "Service_QueryIotHub", + "operationId": "RegistryManager_QueryIotHub", "consumes": [ "application/json" ], @@ -327,13 +327,38 @@ }, { "$ref": "#/parameters/api-version" + }, + { + "name": "x-ms-continuation", + "in": "header", + "required": false, + "type": "string" + }, + { + "name": "x-ms-max-item-count", + "in": "header", + "required": false, + "type": "string" } ], "responses": { "200": { "description": "Query result with continuation token if appropriate.", "schema": { - "$ref": "#/definitions/QueryResult" + "type": "array", + "items": { + "$ref": "#/definitions/Twin" + } + }, + "headers": { + "x-ms-item-type": { + "description": "Type of the list of items.", + "type": "string" + }, + "x-ms-continuation": { + "description": "Continuation token", + "type": "string" + } } } } @@ -343,7 +368,7 @@ "get": { "summary": "Retrieve a device from the identity registry of an IoT hub.", "description": "Retrieve a device from the identity registry of an IoT hub.", - "operationId": "Service_GetDevice", + "operationId": "RegistryManager_GetDevice", "consumes": [], "produces": [ "application/json" @@ -372,7 +397,7 @@ "put": { "summary": "Create or update the identity of a device in the identity registry of an IoT hub.", "description": "Create or update the identity of a device in the identity registry of an IoT hub. An ETag must not be specified for the create operation. An ETag must be specified for the update operation. Note that generationId and deviceId cannot be updated by the user.", - "operationId": "Service_CreateOrUpdateDevice", + "operationId": "RegistryManager_CreateOrUpdateDevice", "consumes": [ "application/json" ], @@ -417,7 +442,7 @@ "delete": { "summary": "Delete the identity of a device from the identity registry of an IoT hub.", "description": "Delete the identity of a device from the identity registry of an IoT hub. This request requires the If-Match header. The client may specify the ETag for the device identity on the request in order to compare to the ETag maintained by the service for the purpose of optimistic concurrency. The delete operation is performed only if the ETag sent by the client matches the value maintained by the server, indicating that the device identity has not been modified since it was retrieved by the client. To force an unconditional delete, set If-Match to the wildcard character (*).", - "operationId": "Service_DeleteDevice", + "operationId": "RegistryManager_DeleteDevice", "consumes": [], "produces": [ "application/json" @@ -451,7 +476,7 @@ "post": { "summary": "Applies the provided configuration content to the specified edge device.", "description": "Applies the provided configuration content to the specified edge device. Configuration content must have modules content", - "operationId": "Service_ApplyConfigurationOnEdgeDevice", + "operationId": "Configuration_ApplyOnEdgeDevice", "consumes": [ "application/json" ], @@ -496,7 +521,7 @@ "post": { "summary": "Create a new import/export job on an IoT hub.", "description": "Create a new import/export job on an IoT hub. See https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-identity-registry#import-and-export-device-identities for more information.", - "operationId": "Service_CreateImportExportJob", + "operationId": "JobClient_CreateImportExportJob", "consumes": [ "application/json" ], @@ -507,6 +532,7 @@ { "name": "jobProperties", "in": "body", + "description": "Specifies the job specification.", "required": true, "schema": { "$ref": "#/definitions/JobProperties" @@ -530,7 +556,7 @@ "get": { "summary": "Gets the status of all import/export jobs in an iot hub", "description": "Gets the status of all import/export jobs in an iot hub. See https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-identity-registry#import-and-export-device-identities for more information.", - "operationId": "Service_GetImportExportJobs", + "operationId": "JobClient_GetImportExportJobs", "consumes": [], "produces": [ "application/json" @@ -557,7 +583,7 @@ "get": { "summary": "Gets the status of an import or export job in an iot hub", "description": "Gets the status of an import or export job in an iot hub. See https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-identity-registry#import-and-export-device-identities for more information.", - "operationId": "Service_GetImportExportJob", + "operationId": "JobClient_GetImportExportJob", "consumes": [], "produces": [ "application/json" @@ -586,7 +612,7 @@ "delete": { "summary": "Cancels an import or export job in an IoT hub.", "description": "Cancels an import or export job in an IoT hub. See https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-identity-registry#import-and-export-device-identities for more information.", - "operationId": "Service_CancelImportExportJob", + "operationId": "JobClient_CancelImportExportJob", "consumes": [], "produces": [ "application/json" @@ -620,7 +646,7 @@ "delete": { "summary": "Deletes all the pending commands for this device from the IoT hub.", "description": "Deletes all the pending commands for this device from the IoT hub", - "operationId": "Service_PurgeCommandQueue", + "operationId": "RegistryManager_PurgeCommandQueue", "consumes": [], "produces": [ "application/json" @@ -650,7 +676,7 @@ "/faultInjection": { "get": { "summary": "Get FaultInjection entity", - "operationId": "Service_GetFaultInjection", + "operationId": "FaultInjection_Get", "consumes": [], "produces": [ "application/json" @@ -671,7 +697,7 @@ }, "put": { "summary": "Create or update FaultInjection entity", - "operationId": "Service_SetFaultInjection", + "operationId": "FaultInjection_Set", "consumes": [ "application/json" ], @@ -702,7 +728,7 @@ "get": { "summary": "Gets a device twin.", "description": "Gets a device twin. See https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-device-twins for more information.", - "operationId": "Service_GetTwin", + "operationId": "Twin_GetDeviceTwin", "consumes": [], "produces": [ "application/json" @@ -731,7 +757,7 @@ "put": { "summary": "Replaces tags and desired properties of a device twin.", "description": "Replaces a device twin. See https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-device-twins for more information.", - "operationId": "Service_ReplaceTwin", + "operationId": "Twin_ReplaceDeviceTwin", "consumes": [ "application/json" ], @@ -777,7 +803,7 @@ "patch": { "summary": "Updates tags and desired properties of a device twin.", "description": "Updates a device twin. See https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-device-twins for more information.", - "operationId": "Service_UpdateTwin", + "operationId": "Twin_UpdateDeviceTwin", "consumes": [ "application/json" ], @@ -825,7 +851,7 @@ "get": { "summary": "Gets a module twin.", "description": "Gets a module twin. See https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-device-twins for more information.", - "operationId": "Service_GetModuleTwin", + "operationId": "Twin_GetModuleTwin", "consumes": [], "produces": [ "application/json" @@ -861,7 +887,7 @@ "put": { "summary": "Replaces tags and desired properties of a module twin.", "description": "Replaces a module twin. See https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-device-twins for more information.", - "operationId": "Service_ReplaceModuleTwin", + "operationId": "Twin_ReplaceModuleTwin", "consumes": [ "application/json" ], @@ -914,7 +940,7 @@ "patch": { "summary": "Updates tags and desired properties of a module twin.", "description": "Updates a module twin. See https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-device-twins for more information.", - "operationId": "Service_UpdateModuleTwin", + "operationId": "Twin_UpdateModuleTwin", "consumes": [ "application/json" ], @@ -968,7 +994,7 @@ "/digitalTwins/{digitalTwinId}/interfaces": { "get": { "summary": "Gets the list of interfaces.", - "operationId": "DigitalTwin_GetInterfaces", + "operationId": "DigitalTwin_GetComponents", "consumes": [], "produces": [ "application/json" @@ -1002,7 +1028,7 @@ }, "patch": { "summary": "Updates desired properties of multiple interfaces.\r\n Example URI: \"digitalTwins/{digitalTwinId}/interfaces\"", - "operationId": "DigitalTwin_UpdateInterfaces", + "operationId": "DigitalTwin_UpdateComponent", "consumes": [ "application/json" ], @@ -1055,7 +1081,7 @@ "/digitalTwins/{digitalTwinId}/interfaces/{interfaceName}": { "get": { "summary": "Gets the interface of given interfaceId.\r\n Example URI: \"digitalTwins/{digitalTwinId}/interfaces/{interfaceName}\"", - "operationId": "DigitalTwin_GetInterface", + "operationId": "DigitalTwin_GetComponent", "consumes": [], "produces": [ "application/json" @@ -1095,11 +1121,91 @@ } } }, + "/messages/serviceBound/feedback": { + "get": { + "summary": "This method is used to retrieve feedback of a cloud-to-device message.", + "description": "This method is used to retrieve feedback of a cloud-to-device message See https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-messaging for more information. This capability is only available in the standard tier IoT Hub. For more information, see [Choose the right IoT Hub tier](https://aka.ms/scaleyouriotsolution).", + "operationId": "HttpRuntime_ReceiveFeedbackNotification", + "consumes": [], + "produces": [ + "application/json" + ], + "parameters": [ + { + "$ref": "#/parameters/api-version" + } + ], + "responses": { + "200": { + "description": "The feedback response object" + }, + "204": { + "description": "No Content Sent if feedback queue is empty" + } + } + } + }, + "/messages/serviceBound/feedback/{lockToken}": { + "delete": { + "summary": "This method completes a feedback message.", + "description": "This method completes a feedback message. The lockToken obtained when the message was received must be provided to resolve race conditions when completing, a feedback message. A completed message is deleted from the feedback queue. See https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-messaging for more information.", + "operationId": "HttpRuntime_CompleteFeedbackNotification", + "consumes": [], + "produces": [ + "application/json" + ], + "parameters": [ + { + "name": "lockToken", + "in": "path", + "description": "Lock token.", + "required": true, + "type": "string" + }, + { + "$ref": "#/parameters/api-version" + } + ], + "responses": { + "204": { + "description": "No Content" + } + } + } + }, + "/messages/serviceBound/feedback/{lockToken}/abandon": { + "post": { + "summary": "This method abandons a feedback message.", + "description": "This method abandons a feedback message. The lockToken obtained when the message was received must be provided to resolve race conditions when abandoning, a feedback message. A abandoned message is deleted from the feedback queue. See https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-messaging for more information.", + "operationId": "HttpRuntime_AbandonFeedbackNotification", + "consumes": [], + "produces": [ + "application/json" + ], + "parameters": [ + { + "name": "lockToken", + "in": "path", + "description": "Lock Token.", + "required": true, + "type": "string" + }, + { + "$ref": "#/parameters/api-version" + } + ], + "responses": { + "204": { + "description": "No Content" + } + } + } + }, "/jobs/v2/{id}": { "get": { "summary": "Retrieves details of a scheduled job from an IoT hub.", "description": "Retrieves details of a scheduled job from an IoT hub. See https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-jobs for more information.", - "operationId": "Service_GetJob", + "operationId": "JobClient_GetJob", "consumes": [], "produces": [ "application/json" @@ -1128,7 +1234,7 @@ "put": { "summary": "Creates a new job to schedule update twins or device direct methods on an IoT hub at a scheduled time.", "description": "Creates a new job to schedule update twins or device direct methods on an IoT hub at a scheduled time. See https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-jobs for more information.", - "operationId": "Service_CreateJob", + "operationId": "JobClient_CreateJob", "consumes": [ "application/json" ], @@ -1169,7 +1275,7 @@ "post": { "summary": "Cancels a scheduled job on an IoT hub.", "description": "Cancels a scheduled job on an IoT hub. See https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-devguide-jobs for more information.", - "operationId": "Service_CancelJob", + "operationId": "JobClient_CancelJob", "consumes": [], "produces": [ "application/json" @@ -1200,7 +1306,7 @@ "get": { "summary": "Query an IoT hub to retrieve information regarding jobs using the IoT Hub query language", "description": "Query an IoT hub to retrieve information regarding jobs using the IoT Hub query language. See https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-query-language for more information. Pagination of results is supported. This returns information about jobs only.", - "operationId": "Service_QueryJobs", + "operationId": "JobClient_QueryJobs", "consumes": [], "produces": [ "application/json" @@ -1237,7 +1343,7 @@ "/devices/{id}/modules": { "get": { "summary": "Retrieve all the module identities on the device.", - "operationId": "Service_GetModulesOnDevice", + "operationId": "RegistryManager_GetModulesOnDevice", "consumes": [], "produces": [ "application/json" @@ -1270,7 +1376,7 @@ "/devices/{id}/modules/{mid}": { "get": { "summary": "Retrieve the specified module identity on the device.", - "operationId": "Service_GetModule", + "operationId": "RegistryManager_GetModule", "consumes": [], "produces": [ "application/json" @@ -1305,7 +1411,7 @@ }, "put": { "summary": "Create or update the module identity for device in IoT hub. An ETag must not be specified for the create operation. An ETag must be specified for the update operation. Note that moduleId and generation cannot be updated by the user.", - "operationId": "Service_CreateOrUpdateModule", + "operationId": "RegistryManager_CreateOrUpdateModule", "consumes": [ "application/json" ], @@ -1362,7 +1468,7 @@ }, "delete": { "summary": "Delete the module identity for device of an IoT hub. This request requires the If-Match header. The client may specify the ETag for the device identity on the request in order to compare to the ETag maintained by the service for the purpose of optimistic concurrency. The delete operation is performed only if the ETag sent by the client matches the value maintained by the server, indicating that the device identity has not been modified since it was retrieved by the client. To force an unconditional delete, set If-Match to the wildcard character (*).", - "operationId": "Service_DeleteModule", + "operationId": "RegistryManager_DeleteModule", "consumes": [], "produces": [ "application/json" @@ -1438,7 +1544,36 @@ "type": "string" }, "x-ms-model-id": { - "description": "Id of the model returned.", + "description": "Digital twin model id.", + "type": "string" + }, + "x-ms-model-resolution-status": { + "description": "Digital twin model resolution status: enum [Pending, Success, NotFound, Failed, Resolved, Deleted]", + "type": "string" + }, + "x-ms-model-resolution-description": { + "description": "Digital twin model resolution status description.", + "type": "string" + } + } + }, + "204": { + "description": "Model is not resolved, See the 'x-ms-model-resolution-description' and 'x-ms-model-resolution-status' for the resolution status code and description.", + "headers": { + "ETag": { + "description": "ETag of the digital twin.", + "type": "string" + }, + "x-ms-model-id": { + "description": "Digital twin model id.", + "type": "string" + }, + "x-ms-model-resolution-status": { + "description": "Digital twin model resolution status: enum [Pending, Success, NotFound, Failed, Resolved, Deleted]", + "type": "string" + }, + "x-ms-model-resolution-description": { + "description": "Digital twin model resolution status description.", "type": "string" } } @@ -1450,7 +1585,7 @@ "post": { "summary": "Invoke a direct method on a device.", "description": "Invoke a direct method on a device. See https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-direct-methods for more information.", - "operationId": "Service_InvokeDeviceMethod", + "operationId": "DeviceMethod_InvokeDeviceMethod", "consumes": [ "application/json" ], @@ -1490,7 +1625,7 @@ "post": { "summary": "Invoke a direct method on a module of a device.", "description": "Invoke a direct method on a module of a device. See https://docs.microsoft.com/azure/iot-hub/iot-hub-devguide-direct-methods for more information", - "operationId": "Service_InvokeDeviceModuleMethod", + "operationId": "DeviceMethod_InvokeModuleMethod", "consumes": [ "application/json" ], @@ -1536,7 +1671,7 @@ "post": { "summary": "Invoke a digital twin interface command.", "description": "Invoke a digital twin interface command.", - "operationId": "DigitalTwin_InvokeInterfaceCommand", + "operationId": "DigitalTwin_InvokeComponentCommand", "consumes": [ "application/json" ], @@ -2033,6 +2168,8 @@ "InvalidDeviceScope", "ConnectionForcefullyClosedOnFaultInjection", "ConnectionRejectedOnFaultInjection", + "InvalidEndpointAuthenticationType", + "ManagedIdentityNotEnabled", "InvalidRouteTestInput", "InvalidSourceOnRoute", "RoutingNotEnabled", @@ -2056,6 +2193,7 @@ "InvalidPnPDesiredProperties", "InvalidPnPReportedProperties", "InvalidPnPWritableReportedProperties", + "InvalidDigitalTwinJsonPatch", "GenericUnauthorized", "IotHubNotFound", "IotHubUnauthorizedAccess", @@ -2073,6 +2211,7 @@ "RoutingEndpointResponseForbidden", "InvalidMessageExpiryTime", "OperationNotAvailableInCurrentTier", + "KeyEncryptionKeyRevoked", "DeviceModelMaxPropertiesExceeded", "DeviceModelMaxIndexablePropertiesExceeded", "IotDpsSuspended", @@ -2148,6 +2287,7 @@ "TooManyModulesOnDevice", "ConfigurationCountLimitExceeded", "DigitalTwinModelCountLimitExceeded", + "InterfaceNameCompressionModelCountLimitExceeded", "GenericUnsupportedMediaType", "IncompatibleDataType", "GenericTooManyRequests", @@ -2155,6 +2295,7 @@ "ThrottleBacklogLimitExceeded", "ThrottlingBacklogTimeout", "ThrottlingMaxActiveJobCountExceeded", + "DeviceThrottlingLimitExceeded", "ClientClosedRequest", "GenericServerError", "ServerError", @@ -2179,10 +2320,16 @@ "InvalidPartitionEpoch", "RestoreTimedOut", "StreamReservationFailure", + "SerializationError", "UnexpectedPropertyValue", "OrchestrationOperationFailed", "ModelRepoEndpointError", "ResolutionError", + "UnableToFetchCredentials", + "UnableToFetchTenantInfo", + "UnableToShareIdentity", + "UnableToExpandDiscoveryInfo", + "UnableToExpandComponentInfo", "GenericBadGateway", "InvalidResponseWhileProxying", "GenericServiceUnavailable", @@ -2196,6 +2343,7 @@ "DeviceUnavailable", "ConfigurationNotAvailable", "GroupNotAvailable", + "HostingServiceNotAvailable", "GenericGatewayTimeout", "GatewayTimeout" ], @@ -2243,170 +2391,6 @@ } } }, - "QueryResult": { - "description": "The query result.", - "type": "object", - "properties": { - "type": { - "description": "The query result type.", - "enum": [ - "unknown", - "twin", - "deviceJob", - "jobResponse", - "raw", - "enrollment", - "enrollmentGroup", - "deviceRegistration" - ], - "type": "string" - }, - "items": { - "description": "The query result items, as a collection.", - "type": "array", - "items": { - "type": "object" - } - }, - "continuationToken": { - "description": "Request continuation token.", - "type": "string" - } - } - }, - "JobProperties": { - "type": "object", - "properties": { - "jobId": { - "description": "System generated. Ignored at creation.", - "type": "string" - }, - "startTimeUtc": { - "format": "date-time", - "description": "System generated. Ignored at creation.", - "type": "string" - }, - "endTimeUtc": { - "format": "date-time", - "description": "System generated. Ignored at creation.\r\nRepresents the time the job stopped processing.", - "type": "string" - }, - "type": { - "description": "Required.\r\nThe type of job to execute.", - "enum": [ - "unknown", - "export", - "import", - "backup", - "readDeviceProperties", - "writeDeviceProperties", - "updateDeviceConfiguration", - "rebootDevice", - "factoryResetDevice", - "firmwareUpdate", - "scheduleDeviceMethod", - "scheduleUpdateTwin", - "restoreFromBackup", - "failoverDataCopy" - ], - "type": "string" - }, - "status": { - "description": "System generated. Ignored at creation.", - "enum": [ - "unknown", - "enqueued", - "running", - "completed", - "failed", - "cancelled", - "scheduled", - "queued" - ], - "type": "string" - }, - "progress": { - "format": "int32", - "description": "System generated. Ignored at creation.\r\nRepresents the percentage of completion.", - "type": "integer" - }, - "inputBlobContainerUri": { - "description": "URI containing SAS token to a blob container that contains registry data to sync.", - "type": "string" - }, - "inputBlobName": { - "description": "The blob name to be used when importing from the provided input blob container.", - "type": "string" - }, - "outputBlobContainerUri": { - "description": "URI containing SAS token to a blob container. This is used to output the status of the job and the results.", - "type": "string" - }, - "outputBlobName": { - "description": "The name of the blob that will be created in the provided output blob container. This blob will contain\r\nthe exported device registry information for the IoT Hub.", - "type": "string" - }, - "excludeKeysInExport": { - "description": "Optional for export jobs; ignored for other jobs. Default: false. If false, authorization keys are included\r\nin export output. Keys are exported as null otherwise.", - "type": "boolean" - }, - "failureReason": { - "description": "System genereated. Ignored at creation.\r\nIf status == failure, this represents a string containing the reason.", - "type": "string" - } - } - }, - "PurgeMessageQueueResult": { - "description": "Result of a device message queue purge operation.", - "type": "object", - "properties": { - "totalMessagesPurged": { - "format": "int32", - "type": "integer" - }, - "deviceId": { - "description": "The ID of the device whose messages are being purged.", - "type": "string" - }, - "moduleId": { - "description": "The ID of the device whose messages are being purged.", - "type": "string" - } - } - }, - "FaultInjectionProperties": { - "type": "object", - "properties": { - "IotHubName": { - "type": "string" - }, - "connection": { - "$ref": "#/definitions/FaultInjectionConnectionProperties" - }, - "lastUpdatedTimeUtc": { - "format": "date-time", - "description": "Service generated.", - "type": "string" - } - } - }, - "FaultInjectionConnectionProperties": { - "type": "object", - "properties": { - "action": { - "enum": [ - "None", - "CloseAll", - "Periodic" - ], - "type": "string" - }, - "blockDurationInMinutes": { - "format": "int32", - "type": "integer" - } - } - }, "Twin": { "description": "Twin Representation", "type": "object", @@ -2520,6 +2504,147 @@ } } }, + "JobProperties": { + "type": "object", + "properties": { + "jobId": { + "description": "System generated. Ignored at creation.", + "type": "string" + }, + "startTimeUtc": { + "format": "date-time", + "description": "System generated. Ignored at creation.", + "type": "string" + }, + "endTimeUtc": { + "format": "date-time", + "description": "System generated. Ignored at creation.\r\nRepresents the time the job stopped processing.", + "type": "string" + }, + "type": { + "description": "Required.\r\nThe type of job to execute.", + "enum": [ + "unknown", + "export", + "import", + "backup", + "readDeviceProperties", + "writeDeviceProperties", + "updateDeviceConfiguration", + "rebootDevice", + "factoryResetDevice", + "firmwareUpdate", + "scheduleDeviceMethod", + "scheduleUpdateTwin", + "restoreFromBackup", + "failoverDataCopy" + ], + "type": "string" + }, + "status": { + "description": "System generated. Ignored at creation.", + "enum": [ + "unknown", + "enqueued", + "running", + "completed", + "failed", + "cancelled", + "scheduled", + "queued" + ], + "type": "string" + }, + "progress": { + "format": "int32", + "description": "System generated. Ignored at creation.\r\nRepresents the percentage of completion.", + "type": "integer" + }, + "inputBlobContainerUri": { + "description": "URI containing SAS token to a blob container that contains registry data to sync.", + "type": "string" + }, + "inputBlobName": { + "description": "The blob name to be used when importing from the provided input blob container.", + "type": "string" + }, + "outputBlobContainerUri": { + "description": "URI containing SAS token to a blob container. This is used to output the status of the job and the results.", + "type": "string" + }, + "outputBlobName": { + "description": "The name of the blob that will be created in the provided output blob container. This blob will contain\r\nthe exported device registry information for the IoT Hub.", + "type": "string" + }, + "excludeKeysInExport": { + "description": "Optional for export jobs; ignored for other jobs. Default: false. If false, authorization keys are included\r\nin export output. Keys are exported as null otherwise.", + "type": "boolean" + }, + "storageAuthenticationType": { + "description": "Specifies authentication type being used for connecting to storage account.", + "enum": [ + "keyBased", + "identityBased" + ], + "type": "string" + }, + "failureReason": { + "description": "System genereated. Ignored at creation.\r\nIf status == failure, this represents a string containing the reason.", + "type": "string" + } + } + }, + "PurgeMessageQueueResult": { + "description": "Result of a device message queue purge operation.", + "type": "object", + "properties": { + "totalMessagesPurged": { + "format": "int32", + "type": "integer" + }, + "deviceId": { + "description": "The ID of the device whose messages are being purged.", + "type": "string" + }, + "moduleId": { + "description": "The ID of the device whose messages are being purged.", + "type": "string" + } + } + }, + "FaultInjectionProperties": { + "type": "object", + "properties": { + "IotHubName": { + "type": "string" + }, + "connection": { + "$ref": "#/definitions/FaultInjectionConnectionProperties" + }, + "lastUpdatedTimeUtc": { + "format": "date-time", + "description": "Service generated.", + "type": "string" + } + } + }, + "FaultInjectionConnectionProperties": { + "type": "object", + "properties": { + "action": { + "enum": [ + "None", + "CloseAll", + "Periodic" + ], + "type": "string" + }, + "blockDurationInMinutes": { + "format": "int32", + "type": "integer" + } + } + }, "DigitalTwinInterfaces": { "type": "object", "properties": { @@ -2793,6 +2918,37 @@ } } }, + "QueryResult": { + "description": "The query result.", + "type": "object", + "properties": { + "type": { + "description": "The query result type.", + "enum": [ + "unknown", + "twin", + "deviceJob", + "jobResponse", + "raw", + "enrollment", + "enrollmentGroup", + "deviceRegistration" + ], + "type": "string" + }, + "items": { + "description": "The query result items, as a collection.", + "type": "array", + "items": { + "type": "object" + } + }, + "continuationToken": { + "description": "Request continuation token.", + "type": "string" + } + } + }, "Module": { "description": "Module identity on a device", "type": "object", @@ -2891,7 +3047,7 @@ "description": "Version of the Api.", "required": true, "type": "string", - "default": "2019-07-01-preview" + "default": "2020-03-13" } } } \ No newline at end of file diff --git a/azure-iot-hub/setup.cfg b/azure-iot-hub/setup.cfg new file mode 100644 index 000000000..ac9d3bb06 --- /dev/null +++ b/azure-iot-hub/setup.cfg @@ -0,0 +1,12 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +#-------------------------------------------------------------------------- + +[bdist_wheel] +# This flag says to generate wheels that support both Python 2 and Python +# 3. If your code will not run unchanged on both Python 2 and 3, you will +# need to generate separate wheels for each Python version that you +# support. +universal=1 diff --git a/azure-iot-hub/setup.py b/azure-iot-hub/setup.py index 3562ac498..ab4cc0641 100644 --- a/azure-iot-hub/setup.py +++ b/azure-iot-hub/setup.py @@ -36,7 +36,7 @@ setup( version=constant["VERSION"], description="Microsoft Azure IoTHub Service Library", license="MIT License", - url="https://github.com/Azure/azure-iot-sdk-python-preview", + url="https://github.com/Azure/azure-iot-sdk-python/tree/master/azure-iot-hub", author="Microsoft Corporation", author_email="opensource@microsoft.com", long_description=_long_description, @@ -54,8 +54,9 @@ setup( "Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", ], - install_requires=["msrest"], + install_requires=["msrest", "uamqp"], python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3*, <4", packages=find_packages( exclude=[ diff --git a/azure-iot-hub/tests/__init__.py b/azure-iot-hub/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/azure-iot-hub/tests/test_iothub_amqp_client.py b/azure-iot-hub/tests/test_iothub_amqp_client.py new file mode 100644 index 000000000..2c4bae91d --- /dev/null +++ b/azure-iot-hub/tests/test_iothub_amqp_client.py @@ -0,0 +1,72 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import pytest +import time +import base64 +import hmac +import hashlib +import copy +import logging +import uamqp +from azure.iot.hub.iothub_amqp_client import IoTHubAmqpClient + +try: + from urllib import quote, quote_plus, urlencode # Py2 +except Exception: + from urllib.parse import quote, quote_plus, urlencode + +"""---Constants---""" + +fake_shared_access_key = "Zm9vYmFy" +fake_shared_access_key_name = "test_key_name" +fake_hostname = "hostname.mytest-net" +fake_device_id = "device_id" +fake_message = "fake_message" + +"""----Shared fixtures----""" + + +@pytest.fixture(scope="function", autouse=True) +def mock_uamqp_SendClient(mocker): + mock_uamqp_SendClient = mocker.patch.object(uamqp, "SendClient") + return mock_uamqp_SendClient + + +@pytest.mark.describe("IoTHubAmqpClient - Amqp Client Connections") +class TestIoTHubAmqpClient(object): + @pytest.mark.it("Send Message To Device") + def test_send_message_to_device(self, mocker, mock_uamqp_SendClient): + iothub_amqp_client = IoTHubAmqpClient( + fake_hostname, fake_shared_access_key_name, fake_shared_access_key + ) + iothub_amqp_client.send_message_to_device(fake_device_id, fake_message) + amqp_client_obj = mock_uamqp_SendClient.return_value + + assert amqp_client_obj.queue_message.call_count == 1 + assert amqp_client_obj.send_all_messages.call_count == 1 + + @pytest.mark.it("Raises an Exception if send_all_messages Fails") + def test_raise_exception_on_send_fail(self, mocker, mock_uamqp_SendClient): + iothub_amqp_client = IoTHubAmqpClient( + fake_hostname, fake_shared_access_key_name, fake_shared_access_key + ) + amqp_client_obj = mock_uamqp_SendClient.return_value + mocker.patch.object( + amqp_client_obj, "send_all_messages", {uamqp.constants.MessageState.SendFailed} + ) + with pytest.raises(Exception): + iothub_amqp_client.send_message_to_device(fake_device_id, fake_message) + + @pytest.mark.it("Disconnect a Device") + def test_disconnect_sync(self, mocker, mock_uamqp_SendClient): + iothub_amqp_client = IoTHubAmqpClient( + fake_hostname, fake_shared_access_key_name, fake_shared_access_key + ) + amqp_client_obj = mock_uamqp_SendClient.return_value + iothub_amqp_client.disconnect_sync() + + assert amqp_client_obj.close.call_count == 1 diff --git a/azure-iot-hub/tests/test_iothub_configuration_manager.py b/azure-iot-hub/tests/test_iothub_configuration_manager.py new file mode 100644 index 000000000..321d92adc --- /dev/null +++ b/azure-iot-hub/tests/test_iothub_configuration_manager.py @@ -0,0 +1,158 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import pytest +from azure.iot.hub.protocol.models import AuthenticationMechanism +from azure.iot.hub.iothub_configuration_manager import IoTHubConfigurationManager + +"""---Constants---""" + +fake_shared_access_key = "Zm9vYmFy" +fake_shared_access_key_name = "alohomora" +fake_hostname = "beauxbatons.academy-net" +fake_device_id = "MyPensieve" +fake_etag = "taggedbymisnitryofmagic" +fake_configuration_id = "fake_configuration_id" + + +class fake_configuration_object: + id = fake_configuration_id + + +fake_configuration = fake_configuration_object() +fake_max_count = 42 +fake_configuration_queries = "fake_configuration_queries" +fake_configuration_content = "fake_configuration_content" + + +"""----Shared fixtures----""" + + +@pytest.fixture(scope="function", autouse=True) +def mock_configuration_operations(mocker): + mock_configuration_operations_init = mocker.patch( + "azure.iot.hub.protocol.iot_hub_gateway_service_ap_is.ConfigurationOperations" + ) + return mock_configuration_operations_init.return_value + + +@pytest.fixture(scope="function") +def iothub_configuration_manager(): + connection_string = "HostName={hostname};DeviceId={device_id};SharedAccessKeyName={skn};SharedAccessKey={sk}".format( + hostname=fake_hostname, + device_id=fake_device_id, + skn=fake_shared_access_key_name, + sk=fake_shared_access_key, + ) + iothub_configuration_manager = IoTHubConfigurationManager(connection_string) + return iothub_configuration_manager + + +@pytest.mark.describe("IoTHubConfigurationManager - .get_configuration()") +class TestGetConfiguration(object): + @pytest.mark.it("Gets configuration") + def test_get(self, mocker, mock_configuration_operations, iothub_configuration_manager): + iothub_configuration_manager.get_configuration(fake_configuration_id) + + assert mock_configuration_operations.get.call_count == 1 + assert mock_configuration_operations.get.call_args == mocker.call(fake_configuration_id) + + +@pytest.mark.describe("IoTHubConfigurationManager - .create_configuration()") +class TestCreateConfiguration(object): + @pytest.mark.it("Creates configuration") + def test_create_configuration( + self, mocker, mock_configuration_operations, iothub_configuration_manager + ): + iothub_configuration_manager.create_configuration(fake_configuration) + + assert mock_configuration_operations.create_or_update.call_count == 1 + assert mock_configuration_operations.create_or_update.call_args == mocker.call( + fake_configuration_id, fake_configuration + ) + + +@pytest.mark.describe("IoTHubConfigurationManager - .update_configuration()") +class TestUpdateConfiguration(object): + @pytest.mark.it("Updates configuration") + def test_update_configuration( + self, mocker, mock_configuration_operations, iothub_configuration_manager + ): + iothub_configuration_manager.update_configuration(fake_configuration, fake_etag) + + assert mock_configuration_operations.create_or_update.call_count == 1 + assert mock_configuration_operations.create_or_update.call_args == mocker.call( + fake_configuration_id, fake_configuration, fake_etag + ) + + +@pytest.mark.describe("IoTHubConfigurationManager - .delete_configuration()") +class TestDeleteConfiguration(object): + @pytest.mark.it("Deletes configuration") + def test_delete_configuration( + self, mocker, mock_configuration_operations, iothub_configuration_manager + ): + iothub_configuration_manager.delete_configuration(fake_configuration_id) + + assert mock_configuration_operations.delete.call_count == 1 + assert mock_configuration_operations.delete.call_args == mocker.call( + fake_configuration_id, "*" + ) + + @pytest.mark.it("Deletes configuration with an etag") + def test_delete_configuration_with_etag( + self, mocker, mock_configuration_operations, iothub_configuration_manager + ): + iothub_configuration_manager.delete_configuration( + configuration_id=fake_configuration_id, etag=fake_etag + ) + + assert mock_configuration_operations.delete.call_count == 1 + assert mock_configuration_operations.delete.call_args == mocker.call( + fake_configuration_id, fake_etag + ) + + +@pytest.mark.describe("IoTHubConfigurationManager - .get_configurations()") +class TestGetConfigurations(object): + @pytest.mark.it("Get configurations") + def test_get_configurations( + self, mocker, mock_configuration_operations, iothub_configuration_manager + ): + iothub_configuration_manager.get_configurations(fake_max_count) + + assert mock_configuration_operations.get_configurations.call_count == 1 + assert mock_configuration_operations.get_configurations.call_args == mocker.call( + fake_max_count + ) + + +@pytest.mark.describe("IoTHubConfigurationManager - .test_configuration_queries()") +class TestTestConfigurationQueries(object): + @pytest.mark.it("Test test_configuration_queries") + def test_test_configuration_queries( + self, mocker, mock_configuration_operations, iothub_configuration_manager + ): + iothub_configuration_manager.test_configuration_queries(fake_configuration_queries) + assert mock_configuration_operations.test_queries.call_count == 1 + assert mock_configuration_operations.test_queries.call_args == mocker.call( + fake_configuration_queries + ) + + +@pytest.mark.describe("IoTHubConfigurationManager - .apply_configuration_on_edge_device()") +class TestApplyConfigurationOnEdgeDevice(object): + @pytest.mark.it("Test apply configuration on edge device") + def test_apply_configuration_on_edge_device( + self, mocker, mock_configuration_operations, iothub_configuration_manager + ): + iothub_configuration_manager.apply_configuration_on_edge_device( + fake_device_id, fake_configuration_content + ) + assert mock_configuration_operations.apply_on_edge_device.call_count == 1 + assert mock_configuration_operations.apply_on_edge_device.call_args == mocker.call( + fake_device_id, fake_configuration_content + ) diff --git a/azure-iot-hub/tests/test_iothub_http_runtime_manager.py b/azure-iot-hub/tests/test_iothub_http_runtime_manager.py new file mode 100644 index 000000000..d5084e8d6 --- /dev/null +++ b/azure-iot-hub/tests/test_iothub_http_runtime_manager.py @@ -0,0 +1,77 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import pytest +from azure.iot.hub.protocol.models import AuthenticationMechanism +from azure.iot.hub.iothub_http_runtime_manager import IoTHubHttpRuntimeManager + +"""---Constants---""" + +fake_hostname = "beauxbatons.academy-net" +fake_device_id = "MyPensieve" +fake_shared_access_key_name = "alohomora" +fake_shared_access_key = "Zm9vYmFy" +fake_lock_token = "fake_lock_token" + + +"""----Shared fixtures----""" + + +@pytest.fixture(scope="function", autouse=True) +def mock_http_runtime_operations(mocker): + mock_http_runtime_operations_init = mocker.patch( + "azure.iot.hub.protocol.iot_hub_gateway_service_ap_is.HttpRuntimeOperations" + ) + return mock_http_runtime_operations_init.return_value + + +@pytest.fixture(scope="function") +def iothub_http_runtime_manager(): + connection_string = "HostName={hostname};DeviceId={device_id};SharedAccessKeyName={skn};SharedAccessKey={sk}".format( + hostname=fake_hostname, + device_id=fake_device_id, + skn=fake_shared_access_key_name, + sk=fake_shared_access_key, + ) + iothub_http_runtime_manager = IoTHubHttpRuntimeManager(connection_string) + return iothub_http_runtime_manager + + +@pytest.mark.describe("IoTHubHttpRuntimeManager - .receive_feedback_notification()") +class TestReceiveFeedbackNotification(object): + @pytest.mark.it("Receive feedback notification") + def test_receive_feedback_notification( + self, mocker, mock_http_runtime_operations, iothub_http_runtime_manager + ): + iothub_http_runtime_manager.receive_feedback_notification() + assert mock_http_runtime_operations.receive_feedback_notification.call_count == 1 + assert mock_http_runtime_operations.receive_feedback_notification.call_args == mocker.call() + + +@pytest.mark.describe("IoTHubHttpRuntimeManager - .complete_feedback_notification()") +class TestCompleteFeedbackNotification(object): + @pytest.mark.it("Complete feedback notification") + def test_complete_feedback_notification( + self, mocker, mock_http_runtime_operations, iothub_http_runtime_manager + ): + iothub_http_runtime_manager.complete_feedback_notification(fake_lock_token) + assert mock_http_runtime_operations.complete_feedback_notification.call_count == 1 + assert mock_http_runtime_operations.complete_feedback_notification.call_args == mocker.call( + fake_lock_token + ) + + +@pytest.mark.describe("IoTHubHttpRuntimeManager - .abandon_feedback_notification()") +class TestAbandonFeedbackNotification(object): + @pytest.mark.it("Abandon feedback notification") + def test_abandon_feedback_notification( + self, mocker, mock_http_runtime_operations, iothub_http_runtime_manager + ): + iothub_http_runtime_manager.abandon_feedback_notification(fake_lock_token) + assert mock_http_runtime_operations.abandon_feedback_notification.call_count == 1 + assert mock_http_runtime_operations.abandon_feedback_notification.call_args == mocker.call( + fake_lock_token + ) diff --git a/azure-iot-hub/tests/test_iothub_registry_manager.py b/azure-iot-hub/tests/test_iothub_registry_manager.py new file mode 100644 index 000000000..20ae293b8 --- /dev/null +++ b/azure-iot-hub/tests/test_iothub_registry_manager.py @@ -0,0 +1,1189 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import pytest +from azure.iot.hub.protocol.models import AuthenticationMechanism +from azure.iot.hub.iothub_registry_manager import IoTHubRegistryManager +from azure.iot.hub.iothub_amqp_client import IoTHubAmqpClient as iothub_amqp_client + +"""---Constants---""" + +fake_shared_access_key = "Zm9vYmFy" +fake_shared_access_key_name = "alohomora" + +fake_primary_key = "petrificus" +fake_secondary_key = "totalus" +fake_primary_thumbprint = "HELFKCPOXAIR9PVNOA3" +fake_secondary_thumbprint = "RGSHARLU4VYYFENINUF" +fake_hostname = "beauxbatons.academy-net" +fake_device_id = "MyPensieve" +fake_module_id = "Divination" +fake_managed_by = "Hogwarts" +fake_etag = "taggedbymisnitryofmagic" +fake_status = "flying" +fake_configuration_id = "fake_configuration" +fake_configuration = "fake_config" +fake_max_count = 42 +fake_configuration_queries = "fake_configuration_queries" +fake_devices = "fake_devices" +fake_query_specification = "fake_query_specification" +fake_configuration_content = "fake_configuration_content" +fake_job_id = "fake_job_id" +fake_start_time = "fake_start_time" +fake_end_time = "fake_end_time" +fake_job_type = "fake_job_type" +fake_job_request = "fake_job_request" +fake_job_status = "fake_status" +fake_job_properties = "fake_job_properties" +fake_device_twin = "fake_device_twin" +fake_module_twin = "fake_module_twin" +fake_direct_method_request = "fake_direct_method_request" +fake_message_to_send = "fake_message_to_send" + +"""----Shared fixtures----""" + + +@pytest.fixture(scope="function", autouse=True) +def mock_registry_manager_operations(mocker): + mock_registry_manager_operations_init = mocker.patch( + "azure.iot.hub.protocol.iot_hub_gateway_service_ap_is.RegistryManagerOperations" + ) + return mock_registry_manager_operations_init.return_value + + +@pytest.fixture(scope="function", autouse=True) +def mock_twin_operations(mocker): + mock_twin_operations_init = mocker.patch( + "azure.iot.hub.protocol.iot_hub_gateway_service_ap_is.TwinOperations" + ) + return mock_twin_operations_init.return_value + + +@pytest.fixture(scope="function", autouse=True) +def mock_device_method_operations(mocker): + mock_device_method_operations_init = mocker.patch( + "azure.iot.hub.protocol.iot_hub_gateway_service_ap_is.DeviceMethodOperations" + ) + return mock_device_method_operations_init.return_value + + +@pytest.fixture(scope="function") +def iothub_registry_manager(): + connection_string = "HostName={hostname};DeviceId={device_id};SharedAccessKeyName={skn};SharedAccessKey={sk}".format( + hostname=fake_hostname, + device_id=fake_device_id, + skn=fake_shared_access_key_name, + sk=fake_shared_access_key, + ) + iothub_registry_manager = IoTHubRegistryManager(connection_string) + return iothub_registry_manager + + +@pytest.fixture(scope="function") +def mock_device_constructor(mocker): + return mocker.patch("azure.iot.hub.iothub_registry_manager.Device") + + +@pytest.fixture(scope="function") +def mock_module_constructor(mocker): + return mocker.patch("azure.iot.hub.iothub_registry_manager.Module") + + +@pytest.fixture(scope="function") +def mock_uamqp_send_message_to_device(mocker): + mock_uamqp_send = mocker.patch.object(iothub_amqp_client, "send_message_to_device") + return mock_uamqp_send + + +@pytest.fixture +def mock_uamqp_disconnect_sync(mocker): + return mocker.patch.object(iothub_amqp_client, "disconnect_sync") + + +@pytest.mark.describe("IoTHubRegistryManager - .create_device_with_sas()") +class TestCreateDeviceWithSymmetricKey(object): + + testdata = [ + (fake_primary_key, None), + (None, fake_secondary_key), + (fake_primary_key, fake_secondary_key), + ] + + @pytest.mark.it("Initializes device with device id, status and sas auth") + @pytest.mark.parametrize( + "primary_key, secondary_key", testdata, ids=["Primary Key", "Secondary Key", "Both Keys"] + ) + def test_initializes_device_with_kwargs_for_sas( + self, iothub_registry_manager, mock_device_constructor, primary_key, secondary_key + ): + iothub_registry_manager.create_device_with_sas( + device_id=fake_device_id, + status=fake_status, + primary_key=primary_key, + secondary_key=secondary_key, + ) + + assert mock_device_constructor.call_count == 1 + + assert mock_device_constructor.call_args[1]["device_id"] == fake_device_id + assert mock_device_constructor.call_args[1]["status"] == fake_status + assert isinstance( + mock_device_constructor.call_args[1]["authentication"], AuthenticationMechanism + ) + auth_mechanism = mock_device_constructor.call_args[1]["authentication"] + assert auth_mechanism.type == "sas" + assert auth_mechanism.x509_thumbprint is None + sym_key = auth_mechanism.symmetric_key + assert sym_key.primary_key == primary_key + assert sym_key.secondary_key == secondary_key + + @pytest.mark.it( + "Calls method from service operations with device id and previously constructed device" + ) + @pytest.mark.parametrize( + "primary_key, secondary_key", testdata, ids=["Primary Key", "Secondary Key", "Both Keys"] + ) + def test_calls_create_or_update_device_for_sas( + self, + mock_device_constructor, + mock_registry_manager_operations, + iothub_registry_manager, + primary_key, + secondary_key, + ): + iothub_registry_manager.create_device_with_sas( + device_id=fake_device_id, + status=fake_status, + primary_key=primary_key, + secondary_key=secondary_key, + ) + + assert mock_registry_manager_operations.create_or_update_device.call_count == 1 + assert ( + mock_registry_manager_operations.create_or_update_device.call_args[0][0] + == fake_device_id + ) + assert ( + mock_registry_manager_operations.create_or_update_device.call_args[0][1] + == mock_device_constructor.return_value + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .create_device_with_x509()") +class TestCreateDeviceWithX509(object): + + testdata = [ + (fake_primary_thumbprint, None), + (None, fake_secondary_thumbprint), + (fake_primary_thumbprint, fake_secondary_thumbprint), + ] + + @pytest.mark.it("Initializes device with device id, status and X509 auth") + @pytest.mark.parametrize( + "primary_thumbprint, secondary_thumbprint", + testdata, + ids=["Primary Thumbprint", "Secondary Thumbprint", "Both Thumbprints"], + ) + def test_initializes_device_with_kwargs_for_x509( + self, + iothub_registry_manager, + mock_device_constructor, + primary_thumbprint, + secondary_thumbprint, + ): + iothub_registry_manager.create_device_with_x509( + device_id=fake_device_id, + status=fake_status, + primary_thumbprint=primary_thumbprint, + secondary_thumbprint=secondary_thumbprint, + ) + + assert mock_device_constructor.call_count == 1 + assert mock_device_constructor.call_args[1]["device_id"] == fake_device_id + assert mock_device_constructor.call_args[1]["status"] == fake_status + assert isinstance( + mock_device_constructor.call_args[1]["authentication"], AuthenticationMechanism + ) + auth_mechanism = mock_device_constructor.call_args[1]["authentication"] + assert auth_mechanism.type == "selfSigned" + assert auth_mechanism.symmetric_key is None + x509_thumbprint = auth_mechanism.x509_thumbprint + assert x509_thumbprint.primary_thumbprint == primary_thumbprint + assert x509_thumbprint.secondary_thumbprint == secondary_thumbprint + + @pytest.mark.it( + "Calls method from service operations with device id and previously constructed device" + ) + @pytest.mark.parametrize( + "primary_thumbprint, secondary_thumbprint", + testdata, + ids=["Primary Thumbprint", "Secondary Thumbprint", "Both Thumbprints"], + ) + def test_calls_create_or_update_device_for_x509( + self, + mock_device_constructor, + mock_registry_manager_operations, + iothub_registry_manager, + primary_thumbprint, + secondary_thumbprint, + ): + iothub_registry_manager.create_device_with_x509( + device_id=fake_device_id, + status=fake_status, + primary_thumbprint=primary_thumbprint, + secondary_thumbprint=secondary_thumbprint, + ) + + assert mock_registry_manager_operations.create_or_update_device.call_count == 1 + assert ( + mock_registry_manager_operations.create_or_update_device.call_args[0][0] + == fake_device_id + ) + assert ( + mock_registry_manager_operations.create_or_update_device.call_args[0][1] + == mock_device_constructor.return_value + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .create_device_with_certificate_authority()") +class TestCreateDeviceWithCA(object): + @pytest.mark.it("Initializes device with device id, status and ca auth") + def test_initializes_device_with_kwargs_for_certificate_authority( + self, mock_device_constructor, iothub_registry_manager + ): + iothub_registry_manager.create_device_with_certificate_authority( + device_id=fake_device_id, status=fake_status + ) + + assert mock_device_constructor.call_count == 1 + assert mock_device_constructor.call_args[1]["device_id"] == fake_device_id + assert mock_device_constructor.call_args[1]["status"] == fake_status + assert isinstance( + mock_device_constructor.call_args[1]["authentication"], AuthenticationMechanism + ) + auth_mechanism = mock_device_constructor.call_args[1]["authentication"] + assert auth_mechanism.type == "certificateAuthority" + assert auth_mechanism.x509_thumbprint is None + assert auth_mechanism.symmetric_key is None + + @pytest.mark.it( + "Calls method from service operations with device id and previously constructed device" + ) + def test_calls_create_or_update_device_for_certificate_authority( + self, mock_device_constructor, mock_registry_manager_operations, iothub_registry_manager + ): + iothub_registry_manager.create_device_with_certificate_authority( + device_id=fake_device_id, status=fake_status + ) + + assert mock_registry_manager_operations.create_or_update_device.call_count == 1 + assert ( + mock_registry_manager_operations.create_or_update_device.call_args[0][0] + == fake_device_id + ) + assert ( + mock_registry_manager_operations.create_or_update_device.call_args[0][1] + == mock_device_constructor.return_value + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .update_device_with_sas()") +class TestUpdateDeviceWithSymmetricKey(object): + + testdata = [ + (fake_primary_key, None), + (None, fake_secondary_key), + (fake_primary_key, fake_secondary_key), + ] + + @pytest.mark.it("Initializes device with device id, status, etag and sas auth") + @pytest.mark.parametrize( + "primary_key, secondary_key", testdata, ids=["Primary Key", "Secondary Key", "Both Keys"] + ) + def test_initializes_device_with_kwargs_for_sas( + self, iothub_registry_manager, mock_device_constructor, primary_key, secondary_key + ): + iothub_registry_manager.update_device_with_sas( + device_id=fake_device_id, + status=fake_status, + etag=fake_etag, + primary_key=primary_key, + secondary_key=secondary_key, + ) + + assert mock_device_constructor.call_count == 1 + + assert mock_device_constructor.call_args[1]["device_id"] == fake_device_id + assert mock_device_constructor.call_args[1]["status"] == fake_status + assert isinstance( + mock_device_constructor.call_args[1]["authentication"], AuthenticationMechanism + ) + auth_mechanism = mock_device_constructor.call_args[1]["authentication"] + assert auth_mechanism.type == "sas" + assert auth_mechanism.x509_thumbprint is None + sym_key = auth_mechanism.symmetric_key + assert sym_key.primary_key == primary_key + assert sym_key.secondary_key == secondary_key + assert mock_device_constructor.call_args[1]["etag"] == fake_etag + + @pytest.mark.it( + "Calls method from service operations with device id and previously constructed device" + ) + @pytest.mark.parametrize( + "primary_key, secondary_key", testdata, ids=["Primary Key", "Secondary Key", "Both Keys"] + ) + def test_calls_create_or_update_device_for_sas( + self, + mock_device_constructor, + mock_registry_manager_operations, + iothub_registry_manager, + primary_key, + secondary_key, + ): + iothub_registry_manager.update_device_with_sas( + device_id=fake_device_id, + status=fake_status, + etag=fake_etag, + primary_key=primary_key, + secondary_key=secondary_key, + ) + + assert mock_registry_manager_operations.create_or_update_device.call_count == 1 + assert ( + mock_registry_manager_operations.create_or_update_device.call_args[0][0] + == fake_device_id + ) + assert ( + mock_registry_manager_operations.create_or_update_device.call_args[0][1] + == mock_device_constructor.return_value + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .update_device_with_x509()") +class TestUpdateDeviceWithX509(object): + + testdata = [ + (fake_primary_thumbprint, None), + (None, fake_secondary_thumbprint), + (fake_primary_thumbprint, fake_secondary_thumbprint), + ] + + @pytest.mark.it("Initializes device with device id, status and X509 auth") + @pytest.mark.parametrize( + "primary_thumbprint, secondary_thumbprint", + testdata, + ids=["Primary Thumbprint", "Secondary Thumbprint", "Both Thumbprints"], + ) + def test_initializes_device_with_kwargs_for_x509( + self, + iothub_registry_manager, + mock_device_constructor, + primary_thumbprint, + secondary_thumbprint, + ): + iothub_registry_manager.update_device_with_x509( + device_id=fake_device_id, + status=fake_status, + etag=fake_etag, + primary_thumbprint=primary_thumbprint, + secondary_thumbprint=secondary_thumbprint, + ) + + assert mock_device_constructor.call_count == 1 + assert mock_device_constructor.call_args[1]["device_id"] == fake_device_id + assert mock_device_constructor.call_args[1]["status"] == fake_status + assert isinstance( + mock_device_constructor.call_args[1]["authentication"], AuthenticationMechanism + ) + auth_mechanism = mock_device_constructor.call_args[1]["authentication"] + assert auth_mechanism.type == "selfSigned" + assert auth_mechanism.symmetric_key is None + x509_thumbprint = auth_mechanism.x509_thumbprint + assert x509_thumbprint.primary_thumbprint == primary_thumbprint + assert x509_thumbprint.secondary_thumbprint == secondary_thumbprint + assert mock_device_constructor.call_args[1]["etag"] == fake_etag + + @pytest.mark.it( + "Calls method from service operations with device id and previously constructed device" + ) + @pytest.mark.parametrize( + "primary_thumbprint, secondary_thumbprint", + testdata, + ids=["Primary Thumbprint", "Secondary Thumbprint", "Both Thumbprints"], + ) + def test_calls_create_or_update_device_for_x509( + self, + mock_device_constructor, + mock_registry_manager_operations, + iothub_registry_manager, + primary_thumbprint, + secondary_thumbprint, + ): + iothub_registry_manager.update_device_with_x509( + device_id=fake_device_id, + status=fake_status, + etag=fake_etag, + primary_thumbprint=primary_thumbprint, + secondary_thumbprint=secondary_thumbprint, + ) + + assert mock_registry_manager_operations.create_or_update_device.call_count == 1 + assert ( + mock_registry_manager_operations.create_or_update_device.call_args[0][0] + == fake_device_id + ) + assert ( + mock_registry_manager_operations.create_or_update_device.call_args[0][1] + == mock_device_constructor.return_value + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .update_device_with_certificate_authority()") +class TestUpdateDeviceWithCA(object): + @pytest.mark.it("Initializes device with device id, status and ca auth") + def test_initializes_device_with_kwargs_for_certificate_authority( + self, mock_device_constructor, iothub_registry_manager + ): + iothub_registry_manager.update_device_with_certificate_authority( + device_id=fake_device_id, status=fake_status, etag=fake_etag + ) + + assert mock_device_constructor.call_count == 1 + assert mock_device_constructor.call_args[1]["device_id"] == fake_device_id + assert mock_device_constructor.call_args[1]["status"] == fake_status + assert isinstance( + mock_device_constructor.call_args[1]["authentication"], AuthenticationMechanism + ) + auth_mechanism = mock_device_constructor.call_args[1]["authentication"] + assert auth_mechanism.type == "certificateAuthority" + assert auth_mechanism.x509_thumbprint is None + assert auth_mechanism.symmetric_key is None + assert mock_device_constructor.call_args[1]["etag"] == fake_etag + + @pytest.mark.it( + "Calls method from service operations with device id and previously constructed device" + ) + def test_calls_create_or_update_device_for_certificate_authority( + self, mock_device_constructor, mock_registry_manager_operations, iothub_registry_manager + ): + iothub_registry_manager.update_device_with_certificate_authority( + device_id=fake_device_id, status=fake_status, etag=fake_etag + ) + + assert mock_registry_manager_operations.create_or_update_device.call_count == 1 + assert ( + mock_registry_manager_operations.create_or_update_device.call_args[0][0] + == fake_device_id + ) + assert ( + mock_registry_manager_operations.create_or_update_device.call_args[0][1] + == mock_device_constructor.return_value + ) + + +@pytest.mark.describe("IoTHubRegistryManager -- .get_device()") +class TestGetDevice(object): + @pytest.mark.it("Gets device from service for provided device id") + def test_get_device(self, mocker, mock_registry_manager_operations, iothub_registry_manager): + iothub_registry_manager.get_device(fake_device_id) + + assert mock_registry_manager_operations.get_device.call_count == 1 + assert mock_registry_manager_operations.get_device.call_args == mocker.call(fake_device_id) + + +@pytest.mark.describe("IoTHubRegistryManager - .delete_device()") +class TestDeleteDevice(object): + @pytest.mark.it("Deletes device for the provided device id") + def test_delete_device(self, mocker, mock_registry_manager_operations, iothub_registry_manager): + iothub_registry_manager.delete_device(fake_device_id) + + assert mock_registry_manager_operations.delete_device.call_count == 1 + assert mock_registry_manager_operations.delete_device.call_args == mocker.call( + fake_device_id, "*" + ) + + @pytest.mark.it("Deletes device with an etag for the provided device id and etag") + def test_delete_device_with_etag( + self, mocker, mock_registry_manager_operations, iothub_registry_manager + ): + iothub_registry_manager.delete_device(device_id=fake_device_id, etag=fake_etag) + + assert mock_registry_manager_operations.delete_device.call_count == 1 + assert mock_registry_manager_operations.delete_device.call_args == mocker.call( + fake_device_id, fake_etag + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .create_module_with_sas()") +class TestCreateModuleWithSymmetricKey(object): + + testdata = [ + (fake_primary_key, None), + (None, fake_secondary_key), + (fake_primary_key, fake_secondary_key), + ] + + @pytest.mark.it("Initializes module with device id, module id, managed_by and sas auth") + @pytest.mark.parametrize( + "primary_key, secondary_key", testdata, ids=["Primary Key", "Secondary Key", "Both Keys"] + ) + def test_initializes_device_with_kwargs_for_sas( + self, iothub_registry_manager, mock_module_constructor, primary_key, secondary_key + ): + iothub_registry_manager.create_module_with_sas( + device_id=fake_device_id, + module_id=fake_module_id, + managed_by=fake_managed_by, + primary_key=primary_key, + secondary_key=secondary_key, + ) + + assert mock_module_constructor.call_count == 1 + + assert mock_module_constructor.call_args[1]["module_id"] == fake_module_id + assert mock_module_constructor.call_args[1]["managed_by"] == fake_managed_by + assert mock_module_constructor.call_args[1]["device_id"] == fake_device_id + assert isinstance( + mock_module_constructor.call_args[1]["authentication"], AuthenticationMechanism + ) + auth_mechanism = mock_module_constructor.call_args[1]["authentication"] + assert auth_mechanism.type == "sas" + assert auth_mechanism.x509_thumbprint is None + sym_key = auth_mechanism.symmetric_key + assert sym_key.primary_key == primary_key + assert sym_key.secondary_key == secondary_key + + @pytest.mark.it( + "Calls method from service operations with device id, module id and previously constructed module" + ) + @pytest.mark.parametrize( + "primary_key, secondary_key", testdata, ids=["Primary Key", "Secondary Key", "Both Keys"] + ) + def test_calls_create_or_update_device_for_sas( + self, + mock_module_constructor, + mock_registry_manager_operations, + iothub_registry_manager, + primary_key, + secondary_key, + ): + iothub_registry_manager.create_module_with_sas( + device_id=fake_device_id, + module_id=fake_module_id, + managed_by=fake_managed_by, + primary_key=primary_key, + secondary_key=secondary_key, + ) + + assert mock_registry_manager_operations.create_or_update_module.call_count == 1 + assert ( + mock_registry_manager_operations.create_or_update_module.call_args[0][0] + == fake_device_id + ) + assert ( + mock_registry_manager_operations.create_or_update_module.call_args[0][1] + == fake_module_id + ) + assert ( + mock_registry_manager_operations.create_or_update_module.call_args[0][2] + == mock_module_constructor.return_value + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .create_module_with_x509()") +class TestCreateModuleWithX509(object): + + testdata = [ + (fake_primary_thumbprint, None), + (None, fake_secondary_thumbprint), + (fake_primary_thumbprint, fake_secondary_thumbprint), + ] + + @pytest.mark.it("Initializes module with device id, module id, managed_by and X509 auth") + @pytest.mark.parametrize( + "primary_thumbprint, secondary_thumbprint", + testdata, + ids=["Primary Thumbprint", "Secondary Thumbprint", "Both Thumbprints"], + ) + def test_initializes_device_with_kwargs_for_x509( + self, + iothub_registry_manager, + mock_module_constructor, + primary_thumbprint, + secondary_thumbprint, + ): + iothub_registry_manager.create_module_with_x509( + device_id=fake_device_id, + module_id=fake_module_id, + managed_by=fake_managed_by, + primary_thumbprint=primary_thumbprint, + secondary_thumbprint=secondary_thumbprint, + ) + + assert mock_module_constructor.call_count == 1 + assert mock_module_constructor.call_args[1]["module_id"] == fake_module_id + assert mock_module_constructor.call_args[1]["managed_by"] == fake_managed_by + assert mock_module_constructor.call_args[1]["device_id"] == fake_device_id + assert isinstance( + mock_module_constructor.call_args[1]["authentication"], AuthenticationMechanism + ) + auth_mechanism = mock_module_constructor.call_args[1]["authentication"] + assert auth_mechanism.type == "selfSigned" + assert auth_mechanism.symmetric_key is None + x509_thumbprint = auth_mechanism.x509_thumbprint + assert x509_thumbprint.primary_thumbprint == primary_thumbprint + assert x509_thumbprint.secondary_thumbprint == secondary_thumbprint + + @pytest.mark.it( + "Calls method from service operations with device id, module id and previously constructed module" + ) + @pytest.mark.parametrize( + "primary_thumbprint, secondary_thumbprint", + testdata, + ids=["Primary Thumbprint", "Secondary Thumbprint", "Both Thumbprints"], + ) + def test_calls_create_or_update_device_for_x509( + self, + mock_module_constructor, + mock_registry_manager_operations, + iothub_registry_manager, + primary_thumbprint, + secondary_thumbprint, + ): + iothub_registry_manager.create_module_with_x509( + device_id=fake_device_id, + module_id=fake_module_id, + managed_by=fake_managed_by, + primary_thumbprint=primary_thumbprint, + secondary_thumbprint=secondary_thumbprint, + ) + + assert mock_registry_manager_operations.create_or_update_module.call_count == 1 + assert ( + mock_registry_manager_operations.create_or_update_module.call_args[0][0] + == fake_device_id + ) + assert ( + mock_registry_manager_operations.create_or_update_module.call_args[0][1] + == fake_module_id + ) + assert ( + mock_registry_manager_operations.create_or_update_module.call_args[0][2] + == mock_module_constructor.return_value + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .create_module_with_certificate_authority()") +class TestCreateModuleWithCA(object): + @pytest.mark.it("Initializes module with device id, module id, managed_by and ca auth") + def test_initializes_device_with_kwargs_for_certificate_authority( + self, mock_module_constructor, iothub_registry_manager + ): + iothub_registry_manager.create_module_with_certificate_authority( + device_id=fake_device_id, module_id=fake_module_id, managed_by=fake_managed_by + ) + + assert mock_module_constructor.call_count == 1 + assert mock_module_constructor.call_args[1]["module_id"] == fake_module_id + assert mock_module_constructor.call_args[1]["managed_by"] == fake_managed_by + assert mock_module_constructor.call_args[1]["device_id"] == fake_device_id + assert isinstance( + mock_module_constructor.call_args[1]["authentication"], AuthenticationMechanism + ) + auth_mechanism = mock_module_constructor.call_args[1]["authentication"] + assert auth_mechanism.type == "certificateAuthority" + assert auth_mechanism.x509_thumbprint is None + assert auth_mechanism.symmetric_key is None + + @pytest.mark.it( + "Calls method from service operations with device id, module id and previously constructed module" + ) + def test_calls_create_or_update_device_for_certificate_authority( + self, mock_module_constructor, mock_registry_manager_operations, iothub_registry_manager + ): + iothub_registry_manager.create_module_with_certificate_authority( + device_id=fake_device_id, module_id=fake_module_id, managed_by=fake_managed_by + ) + + assert mock_registry_manager_operations.create_or_update_module.call_count == 1 + assert ( + mock_registry_manager_operations.create_or_update_module.call_args[0][0] + == fake_device_id + ) + assert ( + mock_registry_manager_operations.create_or_update_module.call_args[0][1] + == fake_module_id + ) + assert ( + mock_registry_manager_operations.create_or_update_module.call_args[0][2] + == mock_module_constructor.return_value + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .update_module_with_sas()") +class TestUpdateModuleWithSymmetricKey(object): + + testdata = [(fake_primary_key, None), (None, fake_secondary_key)] + + @pytest.mark.it("Initializes module with device id, module id, managed_by and sas auth") + @pytest.mark.parametrize( + "primary_key, secondary_key", testdata, ids=["Primary Key", "Secondary Key"] + ) + def test_initializes_device_with_kwargs_for_sas( + self, iothub_registry_manager, mock_module_constructor, primary_key, secondary_key + ): + iothub_registry_manager.update_module_with_sas( + device_id=fake_device_id, + module_id=fake_module_id, + managed_by=fake_managed_by, + etag=fake_etag, + primary_key=primary_key, + secondary_key=secondary_key, + ) + + assert mock_module_constructor.call_count == 1 + + assert mock_module_constructor.call_args[1]["module_id"] == fake_module_id + assert mock_module_constructor.call_args[1]["managed_by"] == fake_managed_by + assert mock_module_constructor.call_args[1]["device_id"] == fake_device_id + assert isinstance( + mock_module_constructor.call_args[1]["authentication"], AuthenticationMechanism + ) + auth_mechanism = mock_module_constructor.call_args[1]["authentication"] + assert auth_mechanism.type == "sas" + assert auth_mechanism.x509_thumbprint is None + sym_key = auth_mechanism.symmetric_key + assert sym_key.primary_key == primary_key + assert sym_key.secondary_key == secondary_key + + @pytest.mark.it( + "Calls method from service operations with device id, module id and previously constructed module" + ) + @pytest.mark.parametrize( + "primary_key, secondary_key", testdata, ids=["Primary Key", "Secondary Key"] + ) + def test_calls_create_or_update_device_for_sas( + self, + mock_module_constructor, + mock_registry_manager_operations, + iothub_registry_manager, + primary_key, + secondary_key, + ): + iothub_registry_manager.update_module_with_sas( + device_id=fake_device_id, + module_id=fake_module_id, + etag=fake_etag, + managed_by=fake_managed_by, + primary_key=primary_key, + secondary_key=secondary_key, + ) + + assert mock_registry_manager_operations.create_or_update_module.call_count == 1 + assert ( + mock_registry_manager_operations.create_or_update_module.call_args[0][0] + == fake_device_id + ) + assert ( + mock_registry_manager_operations.create_or_update_module.call_args[0][1] + == fake_module_id + ) + assert ( + mock_registry_manager_operations.create_or_update_module.call_args[0][2] + == mock_module_constructor.return_value + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .update_module_with_x509()") +class TestUpdateModuleWithX509(object): + + testdata = [(fake_primary_thumbprint, None), (None, fake_secondary_thumbprint)] + + @pytest.mark.it("Initializes module with device id, module id, managed_by and X509 auth") + @pytest.mark.parametrize( + "primary_thumbprint, secondary_thumbprint", + testdata, + ids=["Primary Thumbprint", "Secondary Thumbprint"], + ) + def test_initializes_device_with_kwargs_for_x509( + self, + iothub_registry_manager, + mock_module_constructor, + primary_thumbprint, + secondary_thumbprint, + ): + iothub_registry_manager.update_module_with_x509( + device_id=fake_device_id, + module_id=fake_module_id, + etag=fake_etag, + managed_by=fake_managed_by, + primary_thumbprint=primary_thumbprint, + secondary_thumbprint=secondary_thumbprint, + ) + + assert mock_module_constructor.call_count == 1 + assert mock_module_constructor.call_args[1]["module_id"] == fake_module_id + assert mock_module_constructor.call_args[1]["managed_by"] == fake_managed_by + assert mock_module_constructor.call_args[1]["device_id"] == fake_device_id + assert isinstance( + mock_module_constructor.call_args[1]["authentication"], AuthenticationMechanism + ) + auth_mechanism = mock_module_constructor.call_args[1]["authentication"] + assert auth_mechanism.type == "selfSigned" + assert auth_mechanism.symmetric_key is None + x509_thumbprint = auth_mechanism.x509_thumbprint + assert x509_thumbprint.primary_thumbprint == primary_thumbprint + assert x509_thumbprint.secondary_thumbprint == secondary_thumbprint + + @pytest.mark.it( + "Calls method from service operations with device id, module id and previously constructed module" + ) + @pytest.mark.parametrize( + "primary_thumbprint, secondary_thumbprint", + testdata, + ids=["Primary Thumbprint", "Secondary Thumbprint"], + ) + def test_calls_create_or_update_device_for_x509( + self, + mock_module_constructor, + mock_registry_manager_operations, + iothub_registry_manager, + primary_thumbprint, + secondary_thumbprint, + ): + iothub_registry_manager.update_module_with_x509( + device_id=fake_device_id, + module_id=fake_module_id, + etag=fake_etag, + managed_by=fake_managed_by, + primary_thumbprint=primary_thumbprint, + secondary_thumbprint=secondary_thumbprint, + ) + + assert mock_registry_manager_operations.create_or_update_module.call_count == 1 + assert ( + mock_registry_manager_operations.create_or_update_module.call_args[0][0] + == fake_device_id + ) + assert ( + mock_registry_manager_operations.create_or_update_module.call_args[0][1] + == fake_module_id + ) + assert ( + mock_registry_manager_operations.create_or_update_module.call_args[0][2] + == mock_module_constructor.return_value + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .update_module_with_certificate_authority()") +class TestUpdateModuleWithCA(object): + @pytest.mark.it("Initializes module with device id, module id, managed_by and ca auth") + def test_initializes_device_with_kwargs_for_certificate_authority( + self, mock_module_constructor, iothub_registry_manager + ): + iothub_registry_manager.update_module_with_certificate_authority( + device_id=fake_device_id, + module_id=fake_module_id, + etag=fake_etag, + managed_by=fake_managed_by, + ) + + assert mock_module_constructor.call_count == 1 + assert mock_module_constructor.call_args[1]["module_id"] == fake_module_id + assert mock_module_constructor.call_args[1]["managed_by"] == fake_managed_by + assert mock_module_constructor.call_args[1]["device_id"] == fake_device_id + assert isinstance( + mock_module_constructor.call_args[1]["authentication"], AuthenticationMechanism + ) + auth_mechanism = mock_module_constructor.call_args[1]["authentication"] + assert auth_mechanism.type == "certificateAuthority" + assert auth_mechanism.x509_thumbprint is None + assert auth_mechanism.symmetric_key is None + + @pytest.mark.it( + "Calls method from service operations with device id, module id and previously constructed module" + ) + def test_calls_create_or_update_device_for_certificate_authority( + self, mock_module_constructor, mock_registry_manager_operations, iothub_registry_manager + ): + iothub_registry_manager.update_module_with_certificate_authority( + device_id=fake_device_id, + module_id=fake_module_id, + etag=fake_etag, + managed_by=fake_managed_by, + ) + + assert mock_registry_manager_operations.create_or_update_module.call_count == 1 + assert ( + mock_registry_manager_operations.create_or_update_module.call_args[0][0] + == fake_device_id + ) + assert ( + mock_registry_manager_operations.create_or_update_module.call_args[0][1] + == fake_module_id + ) + assert ( + mock_registry_manager_operations.create_or_update_module.call_args[0][2] + == mock_module_constructor.return_value + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .get_module()") +class TestGetModule(object): + @pytest.mark.it("Gets module from service for provided device id and module id") + def test_get_module(self, mocker, mock_registry_manager_operations, iothub_registry_manager): + iothub_registry_manager.get_module(fake_device_id, fake_module_id) + + assert mock_registry_manager_operations.get_module.call_count == 1 + assert mock_registry_manager_operations.get_module.call_args == mocker.call( + fake_device_id, fake_module_id + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .get_modules()") +class TestGetModules(object): + @pytest.mark.it("Gets all modules from service for provided device") + def test_get_module(self, mocker, mock_registry_manager_operations, iothub_registry_manager): + iothub_registry_manager.get_modules(fake_device_id) + + assert mock_registry_manager_operations.get_modules_on_device.call_count == 1 + assert mock_registry_manager_operations.get_modules_on_device.call_args == mocker.call( + fake_device_id + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .delete_module()") +class TestDeleteModule(object): + @pytest.mark.it("Deletes module for the provided device id") + def test_delete_module(self, mocker, mock_registry_manager_operations, iothub_registry_manager): + iothub_registry_manager.delete_module(fake_device_id, fake_module_id) + + assert mock_registry_manager_operations.delete_module.call_count == 1 + assert mock_registry_manager_operations.delete_module.call_args == mocker.call( + fake_device_id, fake_module_id, "*" + ) + + @pytest.mark.it("Deletes module with an etag for the provided device id and etag") + def test_delete_module_with_etag( + self, mocker, mock_registry_manager_operations, iothub_registry_manager + ): + iothub_registry_manager.delete_module( + device_id=fake_device_id, module_id=fake_module_id, etag=fake_etag + ) + + assert mock_registry_manager_operations.delete_module.call_count == 1 + assert mock_registry_manager_operations.delete_module.call_args == mocker.call( + fake_device_id, fake_module_id, fake_etag + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .get_service_statistics()") +class TestGetServiceStats(object): + @pytest.mark.it("Gets service statistics") + def test_get_service_statistics( + self, mocker, mock_registry_manager_operations, iothub_registry_manager + ): + iothub_registry_manager.get_service_statistics() + + assert mock_registry_manager_operations.get_service_statistics.call_count == 1 + assert mock_registry_manager_operations.get_service_statistics.call_args == mocker.call() + + +@pytest.mark.describe("IoTHubRegistryManager - .get_device_registry_statistics()") +class TestGetDeviceRegistryStats(object): + @pytest.mark.it("Gets device registry statistics") + def test_get_device_registry_statistics( + self, mocker, mock_registry_manager_operations, iothub_registry_manager + ): + iothub_registry_manager.get_device_registry_statistics() + + assert mock_registry_manager_operations.get_device_statistics.call_count == 1 + assert mock_registry_manager_operations.get_device_statistics.call_args == mocker.call() + + +@pytest.mark.describe("IoTHubRegistryManager - .get_devices()") +class TestGetDevices(object): + @pytest.mark.it("Gets devices") + def test_get_devices(self, mocker, mock_registry_manager_operations, iothub_registry_manager): + iothub_registry_manager.get_devices() + + assert mock_registry_manager_operations.get_devices.call_count == 1 + assert mock_registry_manager_operations.get_devices.call_args == mocker.call(None) + + +@pytest.mark.describe("IoTHubRegistryManager - .get_devices(max_number_of_devices)") +class TestGetDevicesWithMax(object): + @pytest.mark.it("Gets devices with max_number_of_devices") + def test_get_devices(self, mocker, mock_registry_manager_operations, iothub_registry_manager): + max_number_of_devices = 42 + iothub_registry_manager.get_devices(max_number_of_devices) + + assert mock_registry_manager_operations.get_devices.call_count == 1 + assert mock_registry_manager_operations.get_devices.call_args == mocker.call( + max_number_of_devices + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .bulk_create_or_update_devices()") +class TestBulkCreateUpdateDevices(object): + @pytest.mark.it("Test bulk_create_or_update_devices") + def test_bulk_create_or_update_devices( + self, mocker, mock_registry_manager_operations, iothub_registry_manager + ): + iothub_registry_manager.bulk_create_or_update_devices(fake_devices) + assert mock_registry_manager_operations.bulk_device_crud.call_count == 1 + assert mock_registry_manager_operations.bulk_device_crud.call_args == mocker.call( + fake_devices + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .query_iot_hub()") +class TestQueryIoTHub(object): + @pytest.mark.it("Test query IoTHub") + def test_query_iot_hub(self, mocker, mock_registry_manager_operations, iothub_registry_manager): + iothub_registry_manager.query_iot_hub(fake_query_specification) + assert mock_registry_manager_operations.query_iot_hub.call_count == 1 + assert mock_registry_manager_operations.query_iot_hub.call_args == mocker.call( + fake_query_specification, None, None, None, True + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .query_iot_hub(continuation_token)") +class TestQueryIoTHubWithContinuationToken(object): + @pytest.mark.it("Test query IoTHub with continuation token") + def test_query_iot_hub(self, mocker, mock_registry_manager_operations, iothub_registry_manager): + continuation_token = 42 + iothub_registry_manager.query_iot_hub(fake_query_specification, continuation_token) + assert mock_registry_manager_operations.query_iot_hub.call_count == 1 + assert mock_registry_manager_operations.query_iot_hub.call_args == mocker.call( + fake_query_specification, continuation_token, None, None, True + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .query_iot_hub(continuation_token, max_item_count)") +class TestQueryIoTHubWithContinuationTokenAndMaxItermCount(object): + @pytest.mark.it("Test query IoTHub with continuation token and max item count") + def test_query_iot_hub(self, mocker, mock_registry_manager_operations, iothub_registry_manager): + continuation_token = 42 + max_item_count = 84 + iothub_registry_manager.query_iot_hub( + fake_query_specification, continuation_token, max_item_count + ) + assert mock_registry_manager_operations.query_iot_hub.call_count == 1 + assert mock_registry_manager_operations.query_iot_hub.call_args == mocker.call( + fake_query_specification, continuation_token, max_item_count, None, True + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .get_twin()") +class TestGetTwin(object): + @pytest.mark.it("Test get twin") + def test_get_twin(self, mocker, mock_twin_operations, iothub_registry_manager): + iothub_registry_manager.get_twin(fake_device_id) + assert mock_twin_operations.get_device_twin.call_count == 1 + assert mock_twin_operations.get_device_twin.call_args == mocker.call(fake_device_id) + + +@pytest.mark.describe("IoTHubRegistryManager - .replace_twin()") +class TestReplaceTwin(object): + @pytest.mark.it("Test replace twin") + def test_replace_twin(self, mocker, mock_twin_operations, iothub_registry_manager): + iothub_registry_manager.replace_twin(fake_device_id, fake_device_twin) + assert mock_twin_operations.replace_device_twin.call_count == 1 + assert mock_twin_operations.replace_device_twin.call_args == mocker.call( + fake_device_id, fake_device_twin + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .update_twin()") +class TestUpdateTwin(object): + @pytest.mark.it("Test update twin") + def test_update_twin(self, mocker, mock_twin_operations, iothub_registry_manager): + iothub_registry_manager.update_twin(fake_device_id, fake_device_twin, fake_etag) + assert mock_twin_operations.update_device_twin.call_count == 1 + assert mock_twin_operations.update_device_twin.call_args == mocker.call( + fake_device_id, fake_device_twin, fake_etag + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .get_module_twin()") +class TestGetModuleTwin(object): + @pytest.mark.it("Test get module twin") + def test_get_module_twin(self, mocker, mock_twin_operations, iothub_registry_manager): + iothub_registry_manager.get_module_twin(fake_device_id, fake_module_id) + assert mock_twin_operations.get_module_twin.call_count == 1 + assert mock_twin_operations.get_module_twin.call_args == mocker.call( + fake_device_id, fake_module_id + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .replace_module_twin()") +class TestReplaceModuleTwin(object): + @pytest.mark.it("Test replace module twin") + def test_replace_module_twin(self, mocker, mock_twin_operations, iothub_registry_manager): + iothub_registry_manager.replace_module_twin( + fake_device_id, fake_module_id, fake_module_twin + ) + assert mock_twin_operations.replace_module_twin.call_count == 1 + assert mock_twin_operations.replace_module_twin.call_args == mocker.call( + fake_device_id, fake_module_id, fake_module_twin + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .update_module_twin()") +class TestUpdateModuleTwin(object): + @pytest.mark.it("Test update module twin") + def test_update_module_twin(self, mocker, mock_twin_operations, iothub_registry_manager): + iothub_registry_manager.update_module_twin( + fake_device_id, fake_module_id, fake_module_twin, fake_etag + ) + assert mock_twin_operations.update_module_twin.call_count == 1 + assert mock_twin_operations.update_module_twin.call_args == mocker.call( + fake_device_id, fake_module_id, fake_module_twin, fake_etag + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .invoke_device_method()") +class TestInvokeDeviceMethod(object): + @pytest.mark.it("Test invoke device method") + def test_invoke_device_method( + self, mocker, mock_device_method_operations, iothub_registry_manager + ): + iothub_registry_manager.invoke_device_method(fake_device_id, fake_direct_method_request) + assert mock_device_method_operations.invoke_device_method.call_count == 1 + assert mock_device_method_operations.invoke_device_method.call_args == mocker.call( + fake_device_id, fake_direct_method_request + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .invoke_device_module_method()") +class TestInvokeDeviceModuleMethod(object): + @pytest.mark.it("Test invoke device module method") + def test_invoke_device_module_method( + self, mocker, mock_device_method_operations, iothub_registry_manager + ): + iothub_registry_manager.invoke_device_module_method( + fake_device_id, fake_module_id, fake_direct_method_request + ) + assert mock_device_method_operations.invoke_module_method.call_count == 1 + assert mock_device_method_operations.invoke_module_method.call_args == mocker.call( + fake_device_id, fake_module_id, fake_direct_method_request + ) + + +@pytest.mark.describe("IoTHubRegistryManager - .send_c2d_message()") +class TestSendC2dMessage(object): + @pytest.mark.it("Test send c2d message") + def test_send_c2d_message( + self, mocker, mock_uamqp_send_message_to_device, iothub_registry_manager + ): + + iothub_registry_manager.send_c2d_message(fake_device_id, fake_message_to_send) + + assert mock_uamqp_send_message_to_device.call_count == 1 + assert mock_uamqp_send_message_to_device.call_args == mocker.call( + fake_device_id, fake_message_to_send + ) diff --git a/azure-iot-hub/tests/test_job_manager.py b/azure-iot-hub/tests/test_job_manager.py new file mode 100644 index 000000000..249bea9b4 --- /dev/null +++ b/azure-iot-hub/tests/test_job_manager.py @@ -0,0 +1,126 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import pytest +from azure.iot.hub.protocol.models import AuthenticationMechanism +from azure.iot.hub.iothub_job_manager import IoTHubJobManager + +"""---Constants---""" + +fake_hostname = "beauxbatons.academy-net" +fake_device_id = "MyPensieve" +fake_shared_access_key_name = "alohomora" +fake_shared_access_key = "Zm9vYmFy" +fake_job_properties = "fake_job_properties" +fake_job_id = "fake_job_id" +fake_job_request = "fake_job_request" +fake_job_type = "fake_job_type" +fake_job_status = "fake_job_status" + + +"""----Shared fixtures----""" + + +@pytest.fixture(scope="function", autouse=True) +def mock_job_client_operations(mocker): + mock_job_client_operations_init = mocker.patch( + "azure.iot.hub.protocol.iot_hub_gateway_service_ap_is.JobClientOperations" + ) + return mock_job_client_operations_init.return_value + + +@pytest.fixture(scope="function") +def iothub_job_manager(): + connection_string = "HostName={hostname};DeviceId={device_id};SharedAccessKeyName={skn};SharedAccessKey={sk}".format( + hostname=fake_hostname, + device_id=fake_device_id, + skn=fake_shared_access_key_name, + sk=fake_shared_access_key, + ) + iothub_job_manager = IoTHubJobManager(connection_string) + return iothub_job_manager + + +@pytest.mark.describe("IoTHubJobManager - .create_import_export_job()") +class TestCreateImportExportJob(object): + @pytest.mark.it("Creates export/import job") + def test_create_export_import_job(self, mocker, mock_job_client_operations, iothub_job_manager): + iothub_job_manager.create_import_export_job(fake_job_properties) + assert mock_job_client_operations.create_import_export_job.call_count == 1 + assert mock_job_client_operations.create_import_export_job.call_args == mocker.call( + fake_job_properties + ) + + +@pytest.mark.describe("IoTHubJobManager - .get_import_export_jobs()") +class TestGetImportExportJobs(object): + @pytest.mark.it("Get export/import jobs") + def test_get_export_import_jobs(self, mocker, mock_job_client_operations, iothub_job_manager): + iothub_job_manager.get_import_export_jobs() + assert mock_job_client_operations.get_import_export_jobs.call_count == 1 + assert mock_job_client_operations.get_import_export_jobs.call_args == mocker.call() + + +@pytest.mark.describe("IoTHubJobManager - .get_import_export_job()") +class TestGetImportExportJob(object): + @pytest.mark.it("Get export/import job") + def test_get_export_import_job(self, mocker, mock_job_client_operations, iothub_job_manager): + iothub_job_manager.get_import_export_job(fake_job_id) + assert mock_job_client_operations.get_import_export_job.call_count == 1 + assert mock_job_client_operations.get_import_export_job.call_args == mocker.call( + fake_job_id + ) + + +@pytest.mark.describe("IoTHubJobManager - .cancel_import_export_job()") +class TestCancelImportExportJob(object): + @pytest.mark.it("Cancel export/import job") + def test_cancel_import_export_job(self, mocker, mock_job_client_operations, iothub_job_manager): + iothub_job_manager.cancel_import_export_job(fake_job_id) + assert mock_job_client_operations.cancel_import_export_job.call_count == 1 + assert mock_job_client_operations.cancel_import_export_job.call_args == mocker.call( + fake_job_id + ) + + +@pytest.mark.describe("IoTHubJobManager - .create_job()") +class TestCreateJob(object): + @pytest.mark.it("Create job") + def test_create_job(self, mocker, mock_job_client_operations, iothub_job_manager): + iothub_job_manager.create_job(fake_job_id, fake_job_request) + assert mock_job_client_operations.create_job.call_count == 1 + assert mock_job_client_operations.create_job.call_args == mocker.call( + fake_job_id, fake_job_request + ) + + +@pytest.mark.describe("IoTHubJobManager - .get_job()") +class TestGetJob(object): + @pytest.mark.it("Get job") + def test_get_job(self, mocker, mock_job_client_operations, iothub_job_manager): + iothub_job_manager.get_job(fake_job_id) + assert mock_job_client_operations.get_job.call_count == 1 + assert mock_job_client_operations.get_job.call_args == mocker.call(fake_job_id) + + +@pytest.mark.describe("IoTHubJobManager - .cancel_job()") +class TestCancelJob(object): + @pytest.mark.it("Cancel job") + def test_get_job(self, mocker, mock_job_client_operations, iothub_job_manager): + iothub_job_manager.cancel_job(fake_job_id) + assert mock_job_client_operations.cancel_job.call_count == 1 + assert mock_job_client_operations.cancel_job.call_args == mocker.call(fake_job_id) + + +@pytest.mark.describe("IoTHubJobManager - .query_jobs()") +class TestQueryJob(object): + @pytest.mark.it("Query job") + def test_get_job(self, mocker, mock_job_client_operations, iothub_job_manager): + iothub_job_manager.query_jobs(fake_job_type, fake_job_status) + assert mock_job_client_operations.query_jobs.call_count == 1 + assert mock_job_client_operations.query_jobs.call_args == mocker.call( + fake_job_type, fake_job_status + ) diff --git a/azure-iot-nspkg/setup.py b/azure-iot-nspkg/setup.py index 648c7f6ad..9c1b41a03 100644 --- a/azure-iot-nspkg/setup.py +++ b/azure-iot-nspkg/setup.py @@ -46,7 +46,7 @@ setup( license="MIT License", author="Microsoft Corporation", author_email="opensource@microsoft.com", - url="https://github.com/Azure/azure-iot-sdk-python-preview", + url="https://github.com/Azure/azure-iot-sdk-python/tree/master/azure-iot-nspkg", classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", diff --git a/azure_iot_hub_e2e/tests/__init__.py b/azure_iot_hub_e2e/tests/__init__.py new file mode 100644 index 000000000..8db66d3d0 --- /dev/null +++ b/azure_iot_hub_e2e/tests/__init__.py @@ -0,0 +1 @@ +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/azure_iot_hub_e2e/tests/pytest.ini b/azure_iot_hub_e2e/tests/pytest.ini new file mode 100644 index 000000000..2f799d7d7 --- /dev/null +++ b/azure_iot_hub_e2e/tests/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts = --timeout 30 \ No newline at end of file diff --git a/azure_iot_hub_e2e/tests/test_iothub_configuration_manager_e2e.py b/azure_iot_hub_e2e/tests/test_iothub_configuration_manager_e2e.py new file mode 100644 index 000000000..2798b8640 --- /dev/null +++ b/azure_iot_hub_e2e/tests/test_iothub_configuration_manager_e2e.py @@ -0,0 +1,74 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import pytest +import logging +import uuid +from azure.iot.hub.iothub_configuration_manager import IoTHubConfigurationManager +from azure.iot.hub.models import Configuration, ConfigurationContent, ConfigurationMetrics + +logging.basicConfig(level=logging.DEBUG) + +iothub_connection_str = os.getenv("IOTHUB_CONNECTION_STRING") + + +@pytest.mark.describe("Create and test IoTHubConfigurationManager") +class TestConfigurationManager(object): + @pytest.mark.it("Create IoTHubConfigurationManager and create, get and delete configuration") + def test_iot_hub_configuration_manager(self): + try: + iothub_configuration = IoTHubConfigurationManager(iothub_connection_str) + + # Create configuration + config_id = "e2e_test_config-" + str(uuid.uuid4()) + + config = Configuration() + config.id = config_id + + content = ConfigurationContent( + device_content={ + "properties.desired.chiller-water": {"temperature: 68, pressure:28"} + } + ) + config.content = content + + metrics = ConfigurationMetrics( + queries={ + "waterSettingPending": "SELECT deviceId FROM devices WHERE properties.reported.chillerWaterSettings.status='pending'" + } + ) + config.metrics = metrics + + # Create configuration + new_config = iothub_configuration.create_configuration(config) + + # Verify result + assert new_config.id == config_id + + # Get configuration + get_config = iothub_configuration.get_configuration(config_id) + + # Verify result + assert get_config.id == config_id + + # Get all configurations + all_configurations = iothub_configuration.get_configurations() + + # Verify result + assert get_config in all_configurations + + # Delete configuration + iothub_configuration.delete_configuration(config_id) + + # Get all configurations + all_configurations = iothub_configuration.get_configurations() + + # # Verify result + assert get_config not in all_configurations + + except Exception as e: + logging.exception(e) diff --git a/azure_iot_hub_e2e/tests/test_iothub_job_manager_e2e.py b/azure_iot_hub_e2e/tests/test_iothub_job_manager_e2e.py new file mode 100644 index 000000000..1c2193787 --- /dev/null +++ b/azure_iot_hub_e2e/tests/test_iothub_job_manager_e2e.py @@ -0,0 +1,111 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import pytest +import logging +import uuid +from azure.iot.hub.iothub_job_manager import IoTHubJobManager +from azure.iot.hub.models import JobProperties, JobRequest + +logging.basicConfig(level=logging.DEBUG) + +iothub_connection_str = os.getenv("IOTHUB_CONNECTION_STRING") +output_container_uri = os.getenv("JOB_EXPORT_IMPORT_OUTPUT_URI") + + +@pytest.mark.describe("Create and test IoTHubJobManager") +class TestJobManager(object): + @pytest.mark.it("Create IoTHubJobManager client and create, get and cancel export/import job") + def test_iot_hub_job_manager_export_import_jobs(self): + try: + iothub_job_manager = IoTHubJobManager(iothub_connection_str) + + # Create export/import job + storage_authentication_type = "keyBased" + properties_type = "export" + job_properties = JobProperties() + job_properties.storage_authentication_type = storage_authentication_type + job_properties.type = properties_type + job_properties.output_blob_container_uri = output_container_uri + + new_export_import_job = iothub_job_manager.create_import_export_job(job_properties) + + # Verify result + assert new_export_import_job.storage_authentication_type == storage_authentication_type + assert new_export_import_job.type == properties_type + assert new_export_import_job.output_blob_container_uri == output_container_uri + + # Get export/import job + get_export_import_job = iothub_job_manager.get_import_export_job( + new_export_import_job.job_id + ) + + # Verify result + assert get_export_import_job.job_id == new_export_import_job.job_id + assert ( + get_export_import_job.storage_authentication_type + == new_export_import_job.storage_authentication_type + ) + assert get_export_import_job.type == new_export_import_job.type + assert ( + get_export_import_job.output_blob_container_uri + == new_export_import_job.output_blob_container_uri + ) + + # Get all export/import jobs + export_import_jobs = iothub_job_manager.get_import_export_jobs() + + assert new_export_import_job in export_import_jobs + + # Cancel export_import job + iothub_job_manager.cancel_import_export_job(new_export_import_job.job_id) + + # Get all export/import jobs + export_import_jobs = iothub_job_manager.get_import_export_jobs() + + assert new_export_import_job not in export_import_jobs + + except Exception as e: + logging.exception(e) + + @pytest.mark.it("Create IoTHubJobManager client and create, get and cancel job") + def test_iot_hub_job_manager_jobs(self): + try: + iothub_job_manager = IoTHubJobManager(iothub_connection_str) + + # Create job request + job_id = "sample_cloud_to_device_method" + job_type = "cloudToDeviceMethod" + job_execution_time_max = 60 + job_request = JobRequest() + job_request.job_id = job_id + job_request.type = job_type + job_request.start_time = "" + job_request.max_execution_time_in_seconds = job_execution_time_max + job_request.update_twin = "" + job_request.query_condition = "" + + new_job_response = iothub_job_manager.create_job(job_request.job_id, job_request) + + # Verify result + assert new_job_response.job_id == job_type + assert new_job_response.type == job_type + assert new_job_response.max_execution_time_in_seconds == job_execution_time_max + + # Get job + get_job = iothub_job_manager.get_job(new_job_response.job_id) + + # Verify result + assert get_job.job_id == job_id + assert get_job.type == job_type + assert get_job.max_execution_time_in_seconds == job_execution_time_max + + # Cancel job + iothub_job_manager.cancel_job(get_job.job_id) + + except Exception as e: + logging.exception(e) diff --git a/azure_iot_hub_e2e/tests/test_iothub_registry_manager_e2e.py b/azure_iot_hub_e2e/tests/test_iothub_registry_manager_e2e.py new file mode 100644 index 000000000..78e2b7e76 --- /dev/null +++ b/azure_iot_hub_e2e/tests/test_iothub_registry_manager_e2e.py @@ -0,0 +1,161 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import pytest +import logging +import uuid +from azure.iot.hub.iothub_registry_manager import IoTHubRegistryManager + +logging.basicConfig(level=logging.DEBUG) + +iothub_connection_str = os.getenv("IOTHUB_CONNECTION_STRING") + + +@pytest.mark.describe("Create and test IoTHubRegistryManager") +class TestRegistryManager(object): + @pytest.mark.it( + "Create IoTHubRegistryManager using SAS authentication and create, get and delete device" + ) + def test_iot_hub_registry_manager_sas(self): + device_id = "e2e-iot-hub-registry-manager-sas-" + str(uuid.uuid4()) + + try: + iothub_registry_manager = IoTHubRegistryManager(iothub_connection_str) + + # Create a device + primary_key = "aaabbbcccdddeeefffggghhhiiijjjkkklllmmmnnnoo" + secondary_key = "111222333444555666777888999000aaabbbcccdddee" + device_state = "enabled" + new_device = iothub_registry_manager.create_device_with_sas( + device_id, primary_key, secondary_key, device_state + ) + + # Verify result + assert new_device.device_id == device_id + assert new_device.authentication.type == "sas" + assert new_device.authentication.symmetric_key.primary_key == primary_key + assert new_device.authentication.symmetric_key.secondary_key == secondary_key + assert new_device.status == device_state + + # Delete device + iothub_registry_manager.delete_device(device_id) + + except Exception as e: + logging.exception(e) + + @pytest.mark.it("Create, get, update and delete device") + @pytest.mark.describe("Create and test IoTHubRegistryManager device CRUD") + def test_iot_hub_registry_manager_sas_crud(self): + device_id = "e2e-iot-hub-registry-manager-sas-" + str(uuid.uuid4()) + + try: + iothub_registry_manager = IoTHubRegistryManager(iothub_connection_str) + + # Create a device + primary_key = "aaabbbcccdddeeefffggghhhiiijjjkkklllmmmnnnoo" + secondary_key = "111222333444555666777888999000aaabbbcccdddee" + device_state = "enabled" + new_device = iothub_registry_manager.create_device_with_sas( + device_id, primary_key, secondary_key, device_state + ) + + # Verify result + assert new_device.device_id == device_id + assert new_device.authentication.type == "sas" + assert new_device.authentication.symmetric_key.primary_key == primary_key + assert new_device.authentication.symmetric_key.secondary_key == secondary_key + assert new_device.status == device_state + + # Update device + updated_status = "disabled" + updated_device = iothub_registry_manager.update_device_with_sas( + device_id, new_device.etag, primary_key, secondary_key, updated_status + ) + + # Verify result + assert updated_device.status == updated_status + + # Delete device + iothub_registry_manager.delete_device(device_id) + + except Exception as e: + logging.exception(e) + + @pytest.mark.it("Create, get, update and delete module") + @pytest.mark.describe("Create and test IoTHubRegistryManager module CRUD") + def test_iot_hub_registry_manager_sas_module_crud(self): + device_id = "e2e-iot-hub-registry-manager-sas-" + str(uuid.uuid4()) + module_id = "e2e-iot-hub-registry-manager-sas-module-" + str(uuid.uuid4()) + + try: + iothub_registry_manager = IoTHubRegistryManager(iothub_connection_str) + + # Create a device + primary_key = "aaabbbcccdddeeefffggghhhiiijjjkkklllmmmnnnoo" + secondary_key = "111222333444555666777888999000aaabbbcccdddee" + device_state = "enabled" + new_device = iothub_registry_manager.create_device_with_sas( + device_id, primary_key, secondary_key, device_state + ) + + # Create module + module_primary_key = "hhhiiijjjkkklllmmmnnnooaaabbbcccdddeeefffggg" + module_secondary_key = "888999000aaabbbcccdddee111222333444555666777" + managed_by = device_id + new_module = iothub_registry_manager.create_module_with_sas( + device_id, module_id, managed_by, module_primary_key, module_secondary_key + ) + + # Verify result + assert new_device.device_id == device_id + assert new_device.authentication.symmetric_key.primary_key == primary_key + assert new_device.authentication.symmetric_key.secondary_key == secondary_key + assert new_device.status == device_state + + assert new_module.module_id == module_id + assert new_module.managed_by == device_id + assert new_module.authentication.type == "sas" + assert new_module.authentication.symmetric_key.primary_key == module_primary_key + assert new_module.authentication.symmetric_key.secondary_key == module_secondary_key + + # Get modules + one_module = iothub_registry_manager.get_modules(device_id) + assert len(one_module) == 1 + + # Update module + update_module_primary_key = "jjjkkklllmmmnnnooaaahhhiiibbbcccdddeeefffggg" + update_module_secondary_key = "000aaabbbcccdddee888999111222333444555666777" + updated_module = iothub_registry_manager.update_module_with_sas( + device_id, + module_id, + managed_by, + new_module.etag, + update_module_primary_key, + update_module_secondary_key, + ) + + # Verify result + assert ( + updated_module.authentication.symmetric_key.primary_key == update_module_primary_key + ) + assert ( + updated_module.authentication.symmetric_key.secondary_key + == update_module_secondary_key + ) + + # Delete module + iothub_registry_manager.delete_module(device_id, module_id) + + # Verify result + no_module = iothub_registry_manager.get_modules(device_id) + assert len(no_module) == 0 + + # Delete device + iothub_registry_manager.delete_device(device_id) + + except Exception as e: + logging.exception(e) diff --git a/azure_provisioning_e2e/tests/test_async_certificate_enrollments.py b/azure_provisioning_e2e/tests/test_async_certificate_enrollments.py index 6711cf3c3..3893b773f 100644 --- a/azure_provisioning_e2e/tests/test_async_certificate_enrollments.py +++ b/azure_provisioning_e2e/tests/test_async_certificate_enrollments.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- + from azure_provisioning_e2e.service_helper import Helper, connection_string_to_hostname from azure.iot.device.aio import ProvisioningDeviceClient from azure.iot.device.common import X509 @@ -19,8 +20,9 @@ import os import uuid from scripts.create_x509_chain_pipeline import ( + before_cert_creation_from_pipeline, call_intermediate_cert_creation_from_pipeline, - call_device_cert_creation_from_pipeline, + create_device_certs, delete_directories_certs_created_from_pipeline, ) @@ -55,12 +57,13 @@ type_to_device_indices = { @pytest.fixture(scope="module", autouse=True) def before_all_tests(request): logging.info("set up certificates before cert related tests") + before_cert_creation_from_pipeline() call_intermediate_cert_creation_from_pipeline( common_name=intermediate_common_name, ca_password=os.getenv("PROVISIONING_ROOT_PASSWORD"), intermediate_password=intermediate_password, ) - call_device_cert_creation_from_pipeline( + create_device_certs( common_name=device_common_name, intermediate_password=intermediate_password, device_password=device_password, @@ -75,11 +78,10 @@ def before_all_tests(request): @pytest.mark.it( - "A device gets provisioned to the linked IoTHub with the user supplied device_id different " - "from the registration_id of the individual enrollment that has been created with a " - "selfsigned X509 authentication" + "A device gets provisioned to the linked IoTHub with the user supplied device_id different from the registration_id of the individual enrollment that has been created with a selfsigned X509 authentication" ) -async def test_device_register_with_device_id_for_a_x509_individual_enrollment(): +@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"]) +async def test_device_register_with_device_id_for_a_x509_individual_enrollment(protocol): device_id = "e2edpsthunderbolt" device_index = type_to_device_indices.get("individual_with_device_id")[0] @@ -92,7 +94,7 @@ async def test_device_register_with_device_id_for_a_x509_individual_enrollment() device_cert_file = "demoCA/newcerts/device_cert" + str(device_index) + ".pem" device_key_file = "demoCA/private/device_key" + str(device_index) + ".pem" registration_result = await result_from_register( - registration_id, device_cert_file, device_key_file + registration_id, device_cert_file, device_key_file, protocol ) assert device_id != registration_id @@ -105,7 +107,8 @@ async def test_device_register_with_device_id_for_a_x509_individual_enrollment() @pytest.mark.it( "A device gets provisioned to the linked IoTHub with device_id equal to the registration_id of the individual enrollment that has been created with a selfsigned X509 authentication" ) -async def test_device_register_with_no_device_id_for_a_x509_individual_enrollment(): +@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"]) +async def test_device_register_with_no_device_id_for_a_x509_individual_enrollment(protocol): device_index = type_to_device_indices.get("individual_no_device_id")[0] try: @@ -117,7 +120,7 @@ async def test_device_register_with_no_device_id_for_a_x509_individual_enrollmen device_cert_file = "demoCA/newcerts/device_cert" + str(device_index) + ".pem" device_key_file = "demoCA/private/device_key" + str(device_index) + ".pem" registration_result = await result_from_register( - registration_id, device_cert_file, device_key_file + registration_id, device_cert_file, device_key_file, protocol ) assert_device_provisioned( @@ -131,7 +134,10 @@ async def test_device_register_with_no_device_id_for_a_x509_individual_enrollmen @pytest.mark.it( "A group of devices get provisioned to the linked IoTHub with device_ids equal to the individual registration_ids inside a group enrollment that has been created with intermediate X509 authentication" ) -async def test_group_of_devices_register_with_no_device_id_for_a_x509_intermediate_authentication_group_enrollment(): +@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"]) +async def test_group_of_devices_register_with_no_device_id_for_a_x509_intermediate_authentication_group_enrollment( + protocol +): group_id = "e2e-intermediate-durmstrang" + str(uuid.uuid4()) common_device_id = device_common_name devices_indices = type_to_device_indices.get("group_intermediate") @@ -172,6 +178,7 @@ async def test_group_of_devices_register_with_no_device_id_for_a_x509_intermedia registration_id=device_id, device_cert_file=device_inter_cert_chain_file, device_key_file=device_key_input_file, + protocol=protocol, ) assert_device_provisioned(device_id=device_id, registration_result=registration_result) @@ -190,7 +197,10 @@ async def test_group_of_devices_register_with_no_device_id_for_a_x509_intermedia @pytest.mark.it( "A group of devices get provisioned to the linked IoTHub with device_ids equal to the individual registration_ids inside a group enrollment that has been created with an already uploaded ca cert X509 authentication" ) -async def test_group_of_devices_register_with_no_device_id_for_a_x509_ca_authentication_group_enrollment(): +@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"]) +async def test_group_of_devices_register_with_no_device_id_for_a_x509_ca_authentication_group_enrollment( + protocol +): group_id = "e2e-ca-ilvermorny" + str(uuid.uuid4()) common_device_id = device_common_name devices_indices = type_to_device_indices.get("group_ca") @@ -232,6 +242,7 @@ async def test_group_of_devices_register_with_no_device_id_for_a_x509_ca_authent registration_id=device_id, device_cert_file=device_inter_cert_chain_file, device_key_file=device_key_input_file, + protocol=protocol, ) assert_device_provisioned(device_id=device_id, registration_result=registration_result) @@ -279,14 +290,15 @@ def create_individual_enrollment_with_x509_client_certs(device_index, device_id= return service_client.create_or_update(individual_provisioning_model) -async def result_from_register(registration_id, device_cert_file, device_key_file): +async def result_from_register(registration_id, device_cert_file, device_key_file, protocol): x509 = X509(cert_file=device_cert_file, key_file=device_key_file, pass_phrase=device_password) - + protocol_boolean_mapping = {"mqtt": False, "mqttws": True} provisioning_device_client = ProvisioningDeviceClient.create_from_x509_certificate( provisioning_host=PROVISIONING_HOST, registration_id=registration_id, id_scope=ID_SCOPE, x509=x509, + websockets=protocol_boolean_mapping[protocol], ) return await provisioning_device_client.register() diff --git a/azure_provisioning_e2e/tests/test_async_symmetric_enrollments.py b/azure_provisioning_e2e/tests/test_async_symmetric_enrollments.py index 12a497ba4..2a320aed4 100644 --- a/azure_provisioning_e2e/tests/test_async_symmetric_enrollments.py +++ b/azure_provisioning_e2e/tests/test_async_symmetric_enrollments.py @@ -32,7 +32,10 @@ linked_iot_hub = connection_string_to_hostname(os.getenv("IOTHUB_CONNECTION_STRI "A device gets provisioned to the linked IoTHub with the device_id equal to the registration_id" "of the individual enrollment that has been created with a symmetric key authentication" ) -async def test_device_register_with_no_device_id_for_a_symmetric_key_individual_enrollment(): +@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"]) +async def test_device_register_with_no_device_id_for_a_symmetric_key_individual_enrollment( + protocol +): try: individual_enrollment_record = create_individual_enrollment( "e2e-dps-legilimens" + str(uuid.uuid4()) @@ -41,7 +44,7 @@ async def test_device_register_with_no_device_id_for_a_symmetric_key_individual_ registration_id = individual_enrollment_record.registration_id symmetric_key = individual_enrollment_record.attestation.symmetric_key.primary_key - registration_result = await result_from_register(registration_id, symmetric_key) + registration_result = await result_from_register(registration_id, symmetric_key, protocol) assert_device_provisioned( device_id=registration_id, registration_result=registration_result @@ -54,7 +57,8 @@ async def test_device_register_with_no_device_id_for_a_symmetric_key_individual_ @pytest.mark.it( "A device gets provisioned to the linked IoTHub with the user supplied device_id different from the registration_id of the individual enrollment that has been created with a symmetric key authentication" ) -async def test_device_register_with_device_id_for_a_symmetric_key_individual_enrollment(): +@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"]) +async def test_device_register_with_device_id_for_a_symmetric_key_individual_enrollment(protocol): device_id = "e2edpsgoldensnitch" try: @@ -65,7 +69,7 @@ async def test_device_register_with_device_id_for_a_symmetric_key_individual_enr registration_id = individual_enrollment_record.registration_id symmetric_key = individual_enrollment_record.attestation.symmetric_key.primary_key - registration_result = await result_from_register(registration_id, symmetric_key) + registration_result = await result_from_register(registration_id, symmetric_key, protocol) assert device_id != registration_id assert_device_provisioned(device_id=device_id, registration_result=registration_result) @@ -111,12 +115,16 @@ def assert_device_provisioned(device_id, registration_result): # TODO Eventually should return result after the APi changes -async def result_from_register(registration_id, symmetric_key): +async def result_from_register(registration_id, symmetric_key, protocol): + # We have this mapping because the pytest logs look better with "mqtt" and "mqttws" + # instead of just "True" and "False". + protocol_boolean_mapping = {"mqtt": False, "mqttws": True} provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( provisioning_host=PROVISIONING_HOST, registration_id=registration_id, id_scope=ID_SCOPE, symmetric_key=symmetric_key, + websockets=protocol_boolean_mapping[protocol], ) return await provisioning_device_client.register() diff --git a/azure_provisioning_e2e/tests/test_sync_certificate_enrollments.py b/azure_provisioning_e2e/tests/test_sync_certificate_enrollments.py index 7230c3d12..5a2d65e68 100644 --- a/azure_provisioning_e2e/tests/test_sync_certificate_enrollments.py +++ b/azure_provisioning_e2e/tests/test_sync_certificate_enrollments.py @@ -19,8 +19,9 @@ import os import uuid from scripts.create_x509_chain_pipeline import ( + before_cert_creation_from_pipeline, call_intermediate_cert_creation_from_pipeline, - call_device_cert_creation_from_pipeline, + create_device_certs, delete_directories_certs_created_from_pipeline, ) @@ -54,12 +55,13 @@ type_to_device_indices = { @pytest.fixture(scope="module", autouse=True) def before_all_tests(request): logging.info("set up certificates before cert related tests") + before_cert_creation_from_pipeline() call_intermediate_cert_creation_from_pipeline( common_name=intermediate_common_name, ca_password=os.getenv("PROVISIONING_ROOT_PASSWORD"), intermediate_password=intermediate_password, ) - call_device_cert_creation_from_pipeline( + create_device_certs( common_name=device_common_name, intermediate_password=intermediate_password, device_password=device_password, @@ -76,7 +78,8 @@ def before_all_tests(request): @pytest.mark.it( "A device gets provisioned to the linked IoTHub with the user supplied device_id different from the registration_id of the individual enrollment that has been created with a selfsigned X509 authentication" ) -def test_device_register_with_device_id_for_a_x509_individual_enrollment(): +@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"]) +def test_device_register_with_device_id_for_a_x509_individual_enrollment(protocol): device_id = "e2edpsflyingfeather" device_index = type_to_device_indices.get("individual_with_device_id")[0] @@ -89,7 +92,7 @@ def test_device_register_with_device_id_for_a_x509_individual_enrollment(): device_cert_file = "demoCA/newcerts/device_cert" + str(device_index) + ".pem" device_key_file = "demoCA/private/device_key" + str(device_index) + ".pem" registration_result = result_from_register( - registration_id, device_cert_file, device_key_file + registration_id, device_cert_file, device_key_file, protocol ) assert device_id != registration_id @@ -102,7 +105,8 @@ def test_device_register_with_device_id_for_a_x509_individual_enrollment(): @pytest.mark.it( "A device gets provisioned to the linked IoTHub with device_id equal to the registration_id of the individual enrollment that has been created with a selfsigned X509 authentication" ) -def test_device_register_with_no_device_id_for_a_x509_individual_enrollment(): +@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"]) +def test_device_register_with_no_device_id_for_a_x509_individual_enrollment(protocol): device_index = type_to_device_indices.get("individual_no_device_id")[0] try: @@ -114,7 +118,7 @@ def test_device_register_with_no_device_id_for_a_x509_individual_enrollment(): device_cert_file = "demoCA/newcerts/device_cert" + str(device_index) + ".pem" device_key_file = "demoCA/private/device_key" + str(device_index) + ".pem" registration_result = result_from_register( - registration_id, device_cert_file, device_key_file + registration_id, device_cert_file, device_key_file, protocol ) assert_device_provisioned( @@ -128,7 +132,10 @@ def test_device_register_with_no_device_id_for_a_x509_individual_enrollment(): @pytest.mark.it( "A group of devices get provisioned to the linked IoTHub with device_ids equal to the individual registration_ids inside a group enrollment that has been created with intermediate X509 authentication" ) -def test_group_of_devices_register_with_no_device_id_for_a_x509_intermediate_authentication_group_enrollment(): +@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"]) +def test_group_of_devices_register_with_no_device_id_for_a_x509_intermediate_authentication_group_enrollment( + protocol +): group_id = "e2e-intermediate-hogwarts" + str(uuid.uuid4()) common_device_id = device_common_name devices_indices = type_to_device_indices.get("group_intermediate") @@ -159,6 +166,7 @@ def test_group_of_devices_register_with_no_device_id_for_a_x509_intermediate_aut device_key_input_file = common_device_key_input_file + str(index) + ".pem" device_cert_input_file = common_device_cert_input_file + str(index) + ".pem" device_inter_cert_chain_file = common_device_inter_cert_chain_file + str(index) + ".pem" + filenames = [device_cert_input_file, intermediate_cert_filename] with open(device_inter_cert_chain_file, "w") as outfile: for fname in filenames: @@ -169,6 +177,7 @@ def test_group_of_devices_register_with_no_device_id_for_a_x509_intermediate_aut registration_id=device_id, device_cert_file=device_inter_cert_chain_file, device_key_file=device_key_input_file, + protocol=protocol, ) assert_device_provisioned(device_id=device_id, registration_result=registration_result) @@ -187,7 +196,10 @@ def test_group_of_devices_register_with_no_device_id_for_a_x509_intermediate_aut @pytest.mark.it( "A group of devices get provisioned to the linked IoTHub with device_ids equal to the individual registration_ids inside a group enrollment that has been created with an already uploaded ca cert X509 authentication" ) -def test_group_of_devices_register_with_no_device_id_for_a_x509_ca_authentication_group_enrollment(): +@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"]) +def test_group_of_devices_register_with_no_device_id_for_a_x509_ca_authentication_group_enrollment( + protocol +): group_id = "e2e-ca-beauxbatons" + str(uuid.uuid4()) common_device_id = device_common_name devices_indices = type_to_device_indices.get("group_ca") @@ -229,6 +241,7 @@ def test_group_of_devices_register_with_no_device_id_for_a_x509_ca_authenticatio registration_id=device_id, device_cert_file=device_inter_cert_chain_file, device_key_file=device_key_input_file, + protocol=protocol, ) assert_device_provisioned(device_id=device_id, registration_result=registration_result) @@ -276,14 +289,15 @@ def create_individual_enrollment_with_x509_client_certs(device_index, device_id= return service_client.create_or_update(individual_provisioning_model) -def result_from_register(registration_id, device_cert_file, device_key_file): +def result_from_register(registration_id, device_cert_file, device_key_file, protocol): x509 = X509(cert_file=device_cert_file, key_file=device_key_file, pass_phrase=device_password) - + protocol_boolean_mapping = {"mqtt": False, "mqttws": True} provisioning_device_client = ProvisioningDeviceClient.create_from_x509_certificate( provisioning_host=PROVISIONING_HOST, registration_id=registration_id, id_scope=ID_SCOPE, x509=x509, + websockets=protocol_boolean_mapping[protocol], ) return provisioning_device_client.register() diff --git a/azure_provisioning_e2e/tests/test_sync_symmetric_enrollments.py b/azure_provisioning_e2e/tests/test_sync_symmetric_enrollments.py index cfd284cf6..7d54b26ee 100644 --- a/azure_provisioning_e2e/tests/test_sync_symmetric_enrollments.py +++ b/azure_provisioning_e2e/tests/test_sync_symmetric_enrollments.py @@ -27,7 +27,8 @@ linked_iot_hub = connection_string_to_hostname(os.getenv("IOTHUB_CONNECTION_STRI @pytest.mark.it( "A device gets provisioned to the linked IoTHub with the device_id equal to the registration_id of the individual enrollment that has been created with a symmetric key authentication" ) -def test_device_register_with_no_device_id_for_a_symmetric_key_individual_enrollment(): +@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"]) +def test_device_register_with_no_device_id_for_a_symmetric_key_individual_enrollment(protocol): try: individual_enrollment_record = create_individual_enrollment( "e2e-dps-underthewhompingwillow" + str(uuid.uuid4()) @@ -36,7 +37,7 @@ def test_device_register_with_no_device_id_for_a_symmetric_key_individual_enroll registration_id = individual_enrollment_record.registration_id symmetric_key = individual_enrollment_record.attestation.symmetric_key.primary_key - registration_result = result_from_register(registration_id, symmetric_key) + registration_result = result_from_register(registration_id, symmetric_key, protocol) assert_device_provisioned( device_id=registration_id, registration_result=registration_result @@ -49,7 +50,8 @@ def test_device_register_with_no_device_id_for_a_symmetric_key_individual_enroll @pytest.mark.it( "A device gets provisioned to the linked IoTHub with the user supplied device_id different from the registration_id of the individual enrollment that has been created with a symmetric key authentication" ) -def test_device_register_with_device_id_for_a_symmetric_key_individual_enrollment(): +@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"]) +def test_device_register_with_device_id_for_a_symmetric_key_individual_enrollment(protocol): device_id = "e2edpstommarvoloriddle" try: @@ -60,7 +62,7 @@ def test_device_register_with_device_id_for_a_symmetric_key_individual_enrollmen registration_id = individual_enrollment_record.registration_id symmetric_key = individual_enrollment_record.attestation.symmetric_key.primary_key - registration_result = result_from_register(registration_id, symmetric_key) + registration_result = result_from_register(registration_id, symmetric_key, protocol) assert device_id != registration_id assert_device_provisioned(device_id=device_id, registration_result=registration_result) @@ -105,12 +107,14 @@ def assert_device_provisioned(device_id, registration_result): assert device.device_id == device_id -def result_from_register(registration_id, symmetric_key): +def result_from_register(registration_id, symmetric_key, protocol): + protocol_boolean_mapping = {"mqtt": False, "mqttws": True} provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( provisioning_host=PROVISIONING_HOST, registration_id=registration_id, id_scope=ID_SCOPE, symmetric_key=symmetric_key, + websockets=protocol_boolean_mapping[protocol], ) return provisioning_device_client.register() diff --git a/credscan_suppression.json b/credscan_suppression.json new file mode 100644 index 000000000..3c738de48 --- /dev/null +++ b/credscan_suppression.json @@ -0,0 +1,14 @@ +{ + "tool": "Credential Scanner", + "suppressions": [ + { + "file": "\\azure_provisioning_e2e\\tests\\test_async_certificate_enrollments.py", + "_justification": "Test containing fake passwords and keys" + }, + { + "file": "\\azure_provisioning_e2e\\tests\\test_sync_certificate_enrollments.py", + "_justification": "Test containing fake passwords and keys" + } + ] + +} \ No newline at end of file diff --git a/migration_guide.md b/migration_guide.md new file mode 100644 index 000000000..076a467a5 --- /dev/null +++ b/migration_guide.md @@ -0,0 +1,147 @@ +# IoTHub Python SDK Migration Guide + +This guide details the migration plan to move from the IoTHub Python v1 code base to the new and improved v2 +code base. + +## Installing the IoTHub Python SDK + +- v1 + +```Shell +pip install azure-iothub-device-client + +``` + +- v2 + +```Shell +pip install azure-iot-device +``` + +## Creating a device client + +When creating a device client on the V1 client the protocol was specified on in the constructor. With the v2 SDK we are +currently only supporting the MQTT protocol so it only requires to supply the connection string when you create the client. + +### Symmetric Key authentication + +- v1 + +```Python + from iothub_client import IoTHubClient, IoTHubClientError, IoTHubTransportProvider, IoTHubClientResult + from iothub_client import IoTHubMessage, IoTHubMessageDispositionResult, IoTHubError, DeviceMethodReturnValue + + client = IoTHubClient(connection_string, IoTHubTransportProvider.MQTT) +``` + +- v2 + +```Python + from azure.iot.device.aio import IoTHubDeviceClient + from azure.iot.device import Message + + client = IoTHubDeviceClient.create_from_connection_string(connection_string) + await device_client.connect() +``` + +### x.509 authentication + +For x.509 device the v1 SDK required the user to supply the certificates in a call to set_options. Moving forward in the v2 +SDK, we only require for the user to call the create function with an x.509 object containing the path to the x.509 file and +key file with the optional pass phrase if neccessary. + +- v1 + +```Python + from iothub_client import IoTHubClient, IoTHubClientError, IoTHubTransportProvider, IoTHubClientResult + from iothub_client import IoTHubMessage, IoTHubMessageDispositionResult, IoTHubError, DeviceMethodReturnValue + + client = IoTHubClient(connection_string, IoTHubTransportProvider.MQTT) + # Get the x.509 certificate information + client.set_option("x509certificate", X509_CERTIFICATE) + client.set_option("x509privatekey", X509_PRIVATEKEY) +``` + +- v2 + +```Python + from azure.iot.device.aio import IoTHubDeviceClient + from azure.iot.device import Message + + # Get the x.509 certificate path from the environment + x509 = X509( + cert_file=os.getenv("X509_CERT_FILE"), + key_file=os.getenv("X509_KEY_FILE"), + pass_phrase=os.getenv("PASS_PHRASE") + ) + client = IoTHubDeviceClient.create_from_x509_certificate(hostname=hostname, device_id=device_id, x509=x509) + await device_client.connect() +``` + +## Sending Telemetry to IoTHub + +- v1 + +```Python + # create the device client + + message = IoTHubMessage("telemetry message") + message.message_id = "message id" + message.correlation_id = "correlation-id" + + prop_map = message.properties() + prop_map.add("property", "property_value") + client.send_event_async(message, send_confirmation_callback, user_ctx) +``` + +- v2 + +```Python + # create the device client + + message = Message("telemetry message") + message.message_id = "message id" + message.correlation_id = "correlation id" + + message.custom_properties["property"] = "property_value" + client.send_message(message) +``` + +## Receiving a Message from IoTHub + +- v1 + +```Python + # create the device client + + def receive_message_callback(message, counter): + global RECEIVE_CALLBACKS + message = message.get_bytearray() + size = len(message_buffer) + print ( "the data in the message received was : <<<%s>>> & Size=%d" % (message_buffer[:size].decode('utf-8'), size) ) + map_properties = message.properties() + key_value_pair = map_properties.get_internals() + print ( "custom properties are: %s" % key_value_pair ) + return IoTHubMessageDispositionResult.ACCEPTED + + client.set_message_callback(message_listener_callback, RECEIVE_CONTEXT) +``` + +- v2 + +```Python + # create the device client + + def message_listener(client): + while True: + message = client.receive_message() # blocking call + print("the data in the message received was ") + print(message.data) + print("custom properties are") + print(message.custom_properties) + + # Run a listener thread in the background + listen_thread = threading.Thread(target=message_listener, args=(device_client,)) + listen_thread.daemon = True + listen_thread.start() +``` diff --git a/requirements_test.txt b/requirements_test.txt index bee7189a8..88a1b686a 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -1,9 +1,9 @@ pytest -pytest-mock +pytest-mock==1.10.4 # breaking change (for us at least) in 1.11.0 renders our tests unpassable under 2.7 pytest-asyncio; python_version >= '3.5' pytest-testdox>=1.1.1 pytest-cov pytest-timeout mock #remove this as soon as no references to it remain in the code flake8 -azure-iothub-provisioningserviceclient >= 1.2.0 # Only needed for end to end tests for DPS +azure-iothub-provisioningserviceclient >= 1.2.0 # Only needed for end to end tests for DPS \ No newline at end of file diff --git a/scripts/build-release.ps1 b/scripts/build-release.ps1 index 26facf97e..3c4393aab 100644 --- a/scripts/build-release.ps1 +++ b/scripts/build-release.ps1 @@ -7,6 +7,10 @@ function Install-Dependencies { function Update-Version($part, $file) { bumpversion.exe $part --config-file .\.bumpverion.cfg --allow-dirty $file + + if($LASTEXITCODE -ne 0) { + throw "Bumpversion failed to increment part '$part' for '$file' with code ($LASTEXITCODE)" + } } function Invoke-Python { @@ -21,9 +25,9 @@ function Build { $sourceFiles = $env:sources # sdk repo top folder $dist = $env:dist # release artifacts top folder - # hashset key is package folder name in repo + # hashtable key is package folder name in repository root - $packages = @{ } # TODO add new packages to this list + $packages = @{ } # TODO add new packages to this hashtable $packages["azure-iot-device"] = [PSCustomObject]@{ File = "azure\iot\device\constant.py" @@ -35,6 +39,11 @@ function Build { Version = $env:nspkg_version_part } + $packages["azure-iot-hub"] = [PSCustomObject]@{ + File = "azure\iot\hub\constant.py" + Version = $env:hub_version_part + } + New-Item $dist -Force -ItemType Directory Install-Dependencies diff --git a/scripts/create_x509_chain_pipeline.py b/scripts/create_x509_chain_pipeline.py index 7cf9c7201..f8bb3cfcd 100644 --- a/scripts/create_x509_chain_pipeline.py +++ b/scripts/create_x509_chain_pipeline.py @@ -1,16 +1,20 @@ import os import re import base64 -import logging import shutil import subprocess - - -# TODO : Do we change all print statements to logging ? -logging.basicConfig(level=logging.DEBUG) +import argparse +import getpass def create_custom_config(): + """ + This function creates a custom configuration based on the already present openssl + configuration file present in local machine. The custom configuration is needed for + creating these certificates for sample and tests. + NOte : For this to work the local openssl conf file path needs to be stored in an + environment variable. + """ # The paths from different OS is different. # For example OS X path is "/usr/local/etc/openssl/openssl.cnf" # Windows path is "C:/Openssl/bin//openssl.cnf" etc @@ -59,64 +63,117 @@ def create_custom_config(): def create_verification_cert( - nonce, root_verify, ca_password="hogwarts", intermediate_password="hogwartsi", key_size=4096 + nonce, root_verify, ca_password=None, intermediate_password=None, key_size=4096 ): - + print(ca_password) print("Done generating verification key") - subject = "//C=US/CN=" + nonce - - if not root_verify: - os.system( - "openssl genrsa -out demoCA/private/verification_inter_key.pem" + " " + str(key_size) - ) - os.system( - "openssl req -key demoCA/private/verification_inter_key.pem" - + " " - + "-new -out demoCA/newcerts/verification_inter_csr.pem -subj " - + subject - ) - print("Done generating verification CSR for intermediate") - - os.system( - "openssl x509 -req -in demoCA/newcerts/verification_inter_csr.pem" - + " " - + "-CA demoCA/newcerts/intermediate_cert.pem -CAkey demoCA/private/intermediate_key.pem -passin pass:" - + intermediate_password - + " " - + "-CAcreateserial -out demoCA/newcerts/verification_inter_cert.pem -days 300 -sha256" - ) - print( - "Done generating verification certificate for intermediate. Upload to IoT Hub to verify" - ) + # subject = "//C=US/CN=" + nonce + subject = "/CN=" + nonce + if root_verify: + key_file = "demoCA/private/verification_root_key.pem" + csr_file = "demoCA/newcerts/verification_root_csr.pem" + in_key_file = "demoCA/private/ca_key.pem" + in_cert_file = "demoCA/newcerts/ca_cert.pem" + out_cert_file = "demoCA/newcerts/verification_root_cert.pem" + passphrase = ca_password else: - os.system( - "openssl genrsa -out demoCA/private/verification_root_key.pem" + " " + str(key_size) - ) - os.system( - "openssl req -key demoCA/private/verification_root_key.pem" - + " " - + "-new -out demoCA/newcerts/verification_root_csr.pem -subj " - + subject - ) - print("Done generating verification CSR") + key_file = "demoCA/private/verification_inter_key.pem" + csr_file = "demoCA/newcerts/verification_inter_csr.pem" + in_key_file = "demoCA/private/intermediate_key.pem" + in_cert_file = "demoCA/newcerts/intermediate_cert.pem" + out_cert_file = "demoCA/newcerts/verification_inter_cert.pem" + passphrase = intermediate_password - os.system( - "openssl x509 -req -in demoCA/newcerts/verification_root_csr.pem" - + " " - + "-CA demoCA/newcerts/ca_cert.pem -CAkey demoCA/private/ca_key.pem -passin pass:" - + ca_password - + " " - + "-CAcreateserial -out demoCA/newcerts/verification_root_cert.pem -days 300 -sha256" - ) + command_verification_key = ["openssl", "genrsa", "-out", key_file, str(key_size)] + + run_verification_key = subprocess.run( + command_verification_key, + universal_newlines=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + print_subprocess_output(run_verification_key) + + command_verification_csr = [ + "openssl", + "req", + "-key", + key_file, + "-new", + "-out", + csr_file, + "-subj", + subject, + ] + + run_verification_csr = subprocess.run( + command_verification_csr, + universal_newlines=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + print_subprocess_output(run_verification_csr) + + command_verification_cert = [ + "openssl", + "x509", + "-req", + "-in", + csr_file, + "-CA", + in_cert_file, + "-CAkey", + in_key_file, + "-passin", + "pass:" + passphrase, + "-CAcreateserial", + "-out", + out_cert_file, + "-days", + str(30), + "-sha256", + ] + + run_verification_cert = subprocess.run( + command_verification_cert, + universal_newlines=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + print_subprocess_output(run_verification_cert) + + if os.path.exists(out_cert_file): print("Done generating verification certificate. Upload to IoT Hub to verify") + else: + print("verification cert NOT generated") -def create_directories_and_prereq_files(): - # os.system("type nul > demoCA/index.txt") - # os.system("type nul > demoCA/index.txt.attr") - os.system("touch demoCA/index.txt") - # os.system("touch demoCA/index.txt.attr") +def print_subprocess_output(run_command): + print(run_command.stdout) + print(run_command.stderr) + print(run_command.returncode) + + +def create_directories_and_prereq_files(pipeline): + """ + This function creates the necessary directories and files. This needs to be called as the first step before doing anything. + :param pipeline: The boolean representing if function has been called from pipeline or not. True for pipeline, False for calling like a script. + """ + os.system("mkdir demoCA") + if pipeline: + # This command does not work when we run locally. So we have to pass in the pipeline variable + os.system("touch demoCA/index.txt") + # TODO Do we need this + # os.system("touch demoCA/index.txt.attr") + else: + os.system("type nul > demoCA/index.txt") + # TODO Do we need this + # os.system("type nul > demoCA/index.txt.attr") + os.system("echo 1000 > demoCA/serial") # Create this folder as configuration file makes new keys go here os.mkdir("demoCA/private") @@ -124,39 +181,86 @@ def create_directories_and_prereq_files(): os.mkdir("demoCA/newcerts") -def create_root(common_name, ca_password="hogwarts", key_size=4096, days=3650): - os.system( - "openssl genrsa -aes256 -out demoCA/private/ca_key.pem -passout pass:" - + ca_password - + " " - + str(key_size) - ) - print("Done generating root key") - # We need another argument like country as there is always error regarding the first argument - # Subject Attribute /C has no known NID, skipped - # So if the first arg is common name the error comes due to common name nad common name is not taken +def create_root(common_name, ca_password, key_size=4096, days=3650): + """ + This function creates the root key and the root certificate. - subject = "//C=US/CN=" + common_name - os.system( - "openssl req -config demoCA/openssl.cnf -key demoCA/private/ca_key.pem -passin pass:" - + ca_password - + " " - + "-new -x509 -days " - + str(days) - + " -sha256 -extensions v3_ca -out demoCA/newcerts/ca_cert.pem -subj " - + subject + :param common_name: The common name to be used in the subject. + :param ca_password: The password for the root certificate which is going to be referenced by the intermediate. + :param key_size: The key size to use for encryption. Default is 4096. + :param days: The number of days for which the certificate is valid. Default is 10 years (3650 days) + """ + command_root_key = [ + "openssl", + "genrsa", + "-aes256", + "-out", + "demoCA/private/ca_key.pem", + "-passout", + "pass:" + ca_password, + str(key_size), + ] + + run_root_key = subprocess.run( + command_root_key, universal_newlines=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) - print("Done generating root certificate") + + print_subprocess_output(run_root_key) + + if os.path.exists("demoCA/private/ca_key.pem"): + print("Done generating root key") + else: + print("root key NOT generated") + + subject = "/CN=" + common_name + + command_root_cert = [ + "openssl", + "req", + "-config", + "demoCA/openssl.cnf", + "-key", + "demoCA/private/ca_key.pem", + "-passin", + "pass:" + ca_password, + "-new", + "-x509", + "-days", + str(days), + "-sha256", + "-extensions", + "v3_ca", + "-out", + "demoCA/newcerts/ca_cert.pem", + "-subj", + subject, + ] + + run_root_cert = subprocess.run( + command_root_cert, universal_newlines=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + + print_subprocess_output(run_root_cert) + + if os.path.exists("demoCA/newcerts/ca_cert.pem"): + print("Done generating root cert") + else: + print("root cert NOT generated") def create_intermediate( - common_name, - pipeline, - ca_password="hogwarts", - intermediate_password="hogwartsi", - key_size=4096, - days=365, + common_name, pipeline, ca_password, intermediate_password, key_size=4096, days=365 ): + """ + This method will create an intermediate key, then an intermediate certificate request and finally an intermediate certificate. + :param common_name: The common name to be used in the subject. + :param pipeline: A boolean variable representing whether this script is being run in Azure Dev Ops pipeline or not. + When this function is called from Azure Dev Ops this variable is True otherwise False + :param ca_password: The password for the root certificate which is going to be referenced by the intermediate. + :param intermediate_password: The password for the intermediate certificate + :param key_size: The key size to use for encryption. Default is 4096. + :param days: The number of days for which the certificate is valid. Default is 1 year (365 days) + """ if pipeline: ca_cert = os.getenv("PROVISIONING_ROOT_CERT") @@ -183,32 +287,63 @@ def create_intermediate( in_cert_file_path = "demoCA/newcerts/ca_cert.pem" in_key_file_path = "demoCA/private/ca_key.pem" - os.system( - "openssl genrsa -aes256 -out demoCA/private/intermediate_key.pem -passout pass:" - + intermediate_password - + " " - + str(key_size) + command_intermediate_key = [ + "openssl", + "genrsa", + "-aes256", + "-out", + "demoCA/private/intermediate_key.pem", + "-passout", + "pass:" + intermediate_password, + str(key_size), + ] + + run_intermediate_key = subprocess.run( + command_intermediate_key, + universal_newlines=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, ) + + print_subprocess_output(run_intermediate_key) + if os.path.exists("demoCA/private/intermediate_key.pem"): print("Done generating intermediate key") else: print("intermediate key NOT generated") subject = "/CN=" + common_name - os.system( - "openssl req -config demoCA/openssl.cnf -key demoCA/private/intermediate_key.pem -passin pass:" - + intermediate_password - + " " - + "-new -sha256 -out demoCA/newcerts/intermediate_csr.pem -subj " - + subject + command_intermediate_csr = [ + "openssl", + "req", + "-config", + "demoCA/openssl.cnf", + "-key", + "demoCA/private/intermediate_key.pem", + "-passin", + "pass:" + intermediate_password, + "-new", + "-sha256", + "-out", + "demoCA/newcerts/intermediate_csr.pem", + "-subj", + subject, + ] + + run_intermediate_csr = subprocess.run( + command_intermediate_csr, + universal_newlines=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, ) + print_subprocess_output(run_intermediate_csr) if os.path.exists("demoCA/newcerts/intermediate_csr.pem"): print("Done generating intermediate CSR") else: print("intermediate csr NOT generated") - command = [ + command_intermediate_cert = [ "openssl", "ca", "-config", @@ -233,12 +368,13 @@ def create_intermediate( "-batch", ] - cp = subprocess.run( - command, universal_newlines=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE + run_intermediate_cert = subprocess.run( + command_intermediate_cert, + universal_newlines=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, ) - print(cp.stdout) - print(cp.stderr) - print(cp.returncode) + print_subprocess_output(run_intermediate_cert) if os.path.exists("demoCA/newcerts/intermediate_cert.pem"): print("Done generating intermediate certificate") @@ -249,17 +385,31 @@ def create_intermediate( def create_certificate_chain( common_name, ca_password, - intermediate_password="hogwartsi", - device_password="hogwartsd", + intermediate_password, + device_password, device_count=1, - pipeline=False, key_size=4096, days=365, ): + """ + This method will create a basic 3 layered chain certificate containing a root, then an intermediate and then some number of leaf certificates. + This function is only used when the certificates are created from script. + + :param common_name: The common name to be used in the subject. This is a single common name which would be applied to all certs created. Since this common name is meant for all, + this common name will be prepended by the words "root", "inter" and "device" for root, intermediate and device certificates. + For device certificates the common name will be further appended with the index of the device. + :param ca_password: The password for the root certificate which is going to be referenced by the intermediate. + :param intermediate_password: The password for the intermediate certificate + :param device_password: The password for the device certificate + :param device_count: The number of leaf devices for which that many number of certificates will be generated. + :param key_size: The key size to use for encryption. The default is 4096. + :param days: The number of days for which the certificate is valid. The default is 1 year or 365 days. + For the root cert this value is multiplied by 10. For the device certificates this number will be divided by 10. + """ common_name_for_root = "root" + common_name create_root(common_name_for_root, ca_password=ca_password, key_size=key_size, days=days * 10) - common_name_for_intermediate = "root" + common_name + common_name_for_intermediate = "inter" + common_name create_intermediate( common_name_for_intermediate, pipeline=False, @@ -279,88 +429,137 @@ def create_certificate_chain( intermediate_password=intermediate_password, device_password=device_password, key_size=key_size, - days=days, + days=int(days / 10), ) def create_leaf_certificates( index, common_name_for_all_device, - intermediate_password="hogwartsi", - device_password="hogwartsd", + intermediate_password, + device_password, key_size=4096, - days=365, + days=30, ): + """ + This function creates leaf or device certificates for a single device within a group represented + by the index in the group. + + :param index: The index representing the ith device in the group. + :param common_name_for_all_device: The common name to be used in the subject. This is applicable + of all the certificates created using this method. The common name will be appended by the + index to create an unique common name for each certificate. + :param intermediate_password: The password for the intermediate certificate + :param device_password: The password for the device certificate + :param key_size: The key size to use for encryption. The default is 4096. + :param days: The number of days for which the certificate is valid. The default is 1 month or 30 days. + """ key_file_name = "device_key" + str(index) + ".pem" csr_file_name = "device_csr" + str(index) + ".pem" cert_file_name = "device_cert" + str(index) + ".pem" - os.system( - "openssl genrsa -aes256 -out demoCA/private/" - + key_file_name - + " -passout pass:" - + device_password - + " " - + str(key_size) + command_device_key = [ + "openssl", + "genrsa", + "-aes256", + "-out", + "demoCA/private/" + key_file_name, + "-passout", + "pass:" + device_password, + str(key_size), + ] + + run_device_key = subprocess.run( + command_device_key, universal_newlines=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) + print(run_device_key.stdout) + print(run_device_key.stderr) + print(run_device_key.returncode) + if os.path.exists("demoCA/private/" + key_file_name): print("Done generating device key with filename {filename}".format(filename=key_file_name)) - logging.debug( - "Done generating device key with filename {filename}".format(filename=key_file_name) - ) else: print("device key NOT generated") - subject = "//C=US/CN=" + common_name_for_all_device + str(index) - os.system( - "openssl req -config demoCA/openssl.cnf -new -sha256 -key demoCA/private/" - + key_file_name - + " -passin pass:" - + device_password - + " " - + "-out demoCA/newcerts/" - + csr_file_name - + " -subj " - + subject + subject = "/CN=" + common_name_for_all_device + str(index) + command_device_csr = [ + "openssl", + "req", + "-config", + "demoCA/openssl.cnf", + "-key", + "demoCA/private/" + key_file_name, + "-passin", + "pass:" + device_password, + "-new", + "-sha256", + "-out", + "demoCA/newcerts/" + csr_file_name, + "-subj", + subject, + ] + + run_device_csr = subprocess.run( + command_device_csr, universal_newlines=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) + print(run_device_csr.stdout) + print(run_device_csr.stderr) + print(run_device_csr.returncode) + if os.path.exists("demoCA/newcerts/" + csr_file_name): print("Done generating device CSR with filename {filename}".format(filename=csr_file_name)) - logging.debug( - "Done generating device CSR with filename {filename}".format(filename=csr_file_name) - ) else: print("device CSR NOT generated") - os.system( - "openssl ca -config demoCA/openssl.cnf -in demoCA/newcerts/" - + csr_file_name - + " -out demoCA/newcerts/" - + cert_file_name - + " -keyfile demoCA/private/intermediate_key.pem -cert demoCA/newcerts/intermediate_cert.pem -passin pass:" - + intermediate_password - + " " - + "-extensions usr_cert -days " - + str(days) - + " -notext -md sha256 -batch" + command_device_cert = [ + "openssl", + "ca", + "-config", + "demoCA/openssl.cnf", + "-in", + "demoCA/newcerts/" + csr_file_name, + "-out", + "demoCA/newcerts/" + cert_file_name, + "-keyfile", + "demoCA/private/intermediate_key.pem", + "-cert", + "demoCA/newcerts/intermediate_cert.pem", + "-passin", + "pass:" + intermediate_password, + "-extensions", + "usr_cert", + "-days", + str(days), + "-notext", + "-md", + "sha256", + "-batch", + ] + + run_device_cert = subprocess.run( + command_device_cert, universal_newlines=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) + print(run_device_cert.stdout) + print(run_device_cert.stderr) + print(run_device_cert.returncode) if os.path.exists("demoCA/newcerts/" + cert_file_name): print( "Done generating device cert with filename {filename}".format(filename=cert_file_name) ) - logging.debug( - "Done generating device cert with filename {filename}".format(filename=cert_file_name) - ) else: print("device cert NOT generated") -def call_intermediate_cert_creation_from_pipeline( - common_name, ca_password, intermediate_password, key_size=4096, days=30 -): - os.system("mkdir demoCA") - create_directories_and_prereq_files() +def before_cert_creation_from_pipeline(): + """ + This function creates the required folder and files before creating certificates. + This also copies an openssl configurtaion file to be used for the generation of this certificates. + NOTE : This function is only applicable when called from the pipeline via E2E tests + and need not be used when it is called as a script. + """ + create_directories_and_prereq_files(True) shutil.copy("config/openssl.cnf", "demoCA/openssl.cnf") @@ -369,12 +568,25 @@ def call_intermediate_cert_creation_from_pipeline( else: print("Configuration file have NOT been copied") - print("ca_password={ca_password}".format(ca_password=ca_password)) - print( - "intermediate_password={intermediate_password}".format( - intermediate_password=intermediate_password - ) - ) + +def call_intermediate_cert_creation_from_pipeline( + common_name, ca_password, intermediate_password, key_size=4096, days=365 +): + """ + This function creates an intermediate certificate by getting called from the pipeline. + This method will create an intermediate key, then an intermediate certificate request and finally an intermediate certificate. + :param common_name: The common name to be used in the subject. + :param ca_password: The password for the root certificate which is going to be referenced by the intermediate. + :param intermediate_password: The password for the intermediate certificate + :param key_size: The key size to use for encryption. Default is 4096. + :param days: The number of days for which the certificate is valid. Default is 1 year (365 days) + :param common_name: The common name of the intermediate certificate. + :param ca_password: The password for the root ca certificate from which the intermediate certificate will be created. + :param intermediate_password: The password for the intermediate certificate. + :param key_size: The key size for the intermediate key. Default is 4096. + :param days: The number of days for hich + :return: + """ create_intermediate( common_name=common_name, @@ -386,27 +598,7 @@ def call_intermediate_cert_creation_from_pipeline( ) -def delete_directories_certs_created_from_pipeline(): - dirPath = "demoCA" - try: - shutil.rmtree(dirPath) - except Exception: - print("Error while deleting directory") - if os.path.exists("out_ca_cert.pem"): - os.remove("out_ca_cert.pem") - else: - print("The file does not exist") - if os.path.exists("out_ca_key.pem"): - os.remove("out_ca_key.pem") - else: - print("The file does not exist") - if os.path.exists(".rnd"): - os.remove(".rnd") - else: - print("The file does not exist") - - -def call_device_cert_creation_from_pipeline( +def create_device_certs( common_name, intermediate_password, device_password, key_size=4096, days=30, device_count=1 ): """ @@ -434,3 +626,167 @@ def call_device_cert_creation_from_pipeline( key_size=key_size, days=days, ) + + +def delete_directories_certs_created_from_pipeline(): + dirPath = "demoCA" + try: + shutil.rmtree(dirPath) + except Exception: + print("Error while deleting directory") + if os.path.exists("out_ca_cert.pem"): + os.remove("out_ca_cert.pem") + else: + print("The file does not exist") + if os.path.exists("out_ca_key.pem"): + os.remove("out_ca_key.pem") + else: + print("The file does not exist") + if os.path.exists(".rnd"): + os.remove(".rnd") + else: + print("The file does not exist") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Generate a certificate chain.") + parser.add_argument("domain", help="Domain name or common name.") + parser.add_argument( + "-s", + "--key-size", + type=int, + help="Size of the key in bits. 2048 bit is quite common. " + + "4096 bit is more secure and the default.", + ) + parser.add_argument( + "-d", + "--days", + type=int, + help="Validity time in days. Default is 10 years for root , 1 year for intermediate and 1 month for leaf", + ) + parser.add_argument( + "--ca-password", type=str, help="CA key password. If omitted it will be prompted." + ) + parser.add_argument( + "--intermediate-password", + type=str, + help="intermediate key password. If omitted it will be prompted.", + ) + parser.add_argument( + "--device-password", type=str, help="device key password. If omitted it will be prompted." + ) + + parser.add_argument( + "--device-count", type=str, help="Number of devices that present in a group. Default is 1." + ) + + parser.add_argument( + "--mode", + type=str, + help="The mode in which certificate is created. By default non-verification mode. For verification use 'verification'", + ) + parser.add_argument( + "--nonce", + type=str, + help="thumprint generated from iot hub certificates. During verification mode if omitted it will be prompted.", + ) + parser.add_argument( + "--root-verify", + type=str, + help="The boolean value to enter in case it is the root or intermediate verification. By default it is True meaning root verifictaion. If veriication of intermediate certification is needed please enter False ", + ) + args = parser.parse_args() + + common_name = args.domain + + if args.key_size: + key_size = args.key_size + else: + key_size = 4096 + if args.days: + days = args.days + else: + days = 30 + + ca_password = None + intermediate_password = None + if args.mode: + if args.mode == "verification": + mode = "verification" + print("in verification mode") + else: + raise ValueError( + "No other mode except verification is accepted. Default is non-verification" + ) + else: + mode = "non-verification" + + if mode == "non-verification": + if args.ca_password: + ca_password = args.ca_password + else: + ca_password = getpass.getpass("Enter pass phrase for root key: ") + if args.intermediate_password: + intermediate_password = args.intermediate_password + else: + intermediate_password = getpass.getpass("Enter pass phrase for intermediate key: ") + if args.device_password: + device_password = args.device_password + else: + device_password = getpass.getpass("Enter pass phrase for device key: ") + if args.device_count: + device_count = args.device_count + else: + device_count = 1 + + else: + print("in verification mode") + if args.nonce: + nonce = args.nonce + print("got nonce") + else: + nonce = getpass.getpass("Enter nonce for verification mode") + if args.root_verify: + lower_root_verify = args.root_verify.lower() + print("root verify is False") + if lower_root_verify == "false": + root_verify = False + if args.intermediate_password: + intermediate_password = args.intermediate_password + else: + intermediate_password = getpass.getpass( + "Enter pass phrase for intermediate key: " + ) + else: + root_verify = True + print("root verify is TRue") + if args.ca_password: + ca_password = args.ca_password + print("putting ca password") + else: + ca_password = getpass.getpass("Enter pass phrase for root key: ") + else: + root_verify = True + print("root verify is default TRue") + if args.ca_password: + ca_password = args.ca_password + else: + ca_password = getpass.getpass("Enter pass phrase for root key: ") + print(ca_password) + + if os.path.exists("demoCA/private/") and os.path.exists("demoCA/newcerts/"): + print("demoCA already exists.") + else: + create_directories_and_prereq_files(False) + create_custom_config() + + if mode == "verification": + create_verification_cert(nonce, root_verify, ca_password, intermediate_password) + else: + create_certificate_chain( + common_name=args.domain, + ca_password=ca_password, + intermediate_password=intermediate_password, + device_password=device_password, + device_count=int(device_count), + ) diff --git a/thirdpartynotice.txt b/thirdpartynotice.txt new file mode 100644 index 000000000..659ddb224 --- /dev/null +++ b/thirdpartynotice.txt @@ -0,0 +1,594 @@ + +Third Party Notices for Azure IoT SDKs project + +This project incorporates material from the project(s) listed below (collectively, “Third Party Code”). +Microsoft Corporation is not the original author of the Third Party Code. +The original copyright notice and license, under which Microsoft Corporation received such Third Party Code, +are set out below. This Third Party Code is licensed to you under their original license terms set forth below. +Microsoft Corporation reserves all other rights not expressly granted, whether by implication, estoppel or otherwise. + + +1.) License Notice for urllib3 from https://raw.githubusercontent.com/urllib3/urllib3/master/LICENSE.txt +----------------------------------------------------------------------------------------------------------------------- + +MIT License + +Copyright (c) 2008-2019 Andrey Petrov and contributors (see CONTRIBUTORS.txt) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +2.) License Notice for six from https://raw.githubusercontent.com/benjaminp/six/master/LICENSE +----------------------------------------------------------------------------------------------------------------------- + +Copyright (c) 2010-2019 Benjamin Peterson + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +3.) License Notice for paho-mqtt from https://raw.githubusercontent.com/PyCQA/astroid/master/COPYING +----------------------------------------------------------------------------------------------------------------------- + +This project is dual licensed under the Eclipse Public License 1.0 and the +Eclipse Distribution License 1.0 as described in the epl-v10 and edl-v10 files. + + +4.) License Notice for transitions from https://raw.githubusercontent.com/pytransitions/transitions/master/LICENSE +----------------------------------------------------------------------------------------------------------------------- + +The MIT License + +Copyright (c) 2014 - 2019 Tal Yarkoni, Alexander Neumann + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +5.) License Notice for requests from https://raw.githubusercontent.com/psf/requests/master/LICENSE +-------------------------------------------------------------------------------------------------------------------- + +Copyright 2019 Kenneth Reitz + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +6.) License Notice for requests-unixsocket from https://raw.githubusercontent.com/msabramo/requests-unixsocket/master/LICENSE +------------------------------------------------------------------------------------------------------------------------------ + +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +7.) License Notice for janus from https://raw.githubusercontent.com/aio-libs/janus/master/LICENSE +------------------------------------------------------------------------------------------------------- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2015-2018 Andrew Svetlov and aio-libs team + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +8.) License Notice for futures from https://raw.githubusercontent.com/agronholm/pythonfutures/master/LICENSE +----------------------------------------------------------------------------------------------------------------------- + +PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +-------------------------------------------- + +1. This LICENSE AGREEMENT is between the Python Software Foundation +("PSF"), and the Individual or Organization ("Licensee") accessing and +otherwise using this software ("Python") in source or binary form and +its associated documentation. + +2. Subject to the terms and conditions of this License Agreement, PSF +hereby grants Licensee a nonexclusive, royalty-free, world-wide +license to reproduce, analyze, test, perform and/or display publicly, +prepare derivative works, distribute, and otherwise use Python +alone or in any derivative version, provided, however, that PSF's +License Agreement and PSF's notice of copyright, i.e., "Copyright (c) +2001, 2002, 2003, 2004, 2005, 2006 Python Software Foundation; All Rights +Reserved" are retained in Python alone or in any derivative version +prepared by Licensee. + +3. In the event Licensee prepares a derivative work that is based on +or incorporates Python or any part thereof, and wants to make +the derivative work available to others as provided herein, then +Licensee hereby agrees to include in any such work a brief summary of +the changes made to Python. + +4. PSF is making Python available to Licensee on an "AS IS" +basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +INFRINGE ANY THIRD PARTY RIGHTS. + +5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. + +6. This License Agreement will automatically terminate upon a material +breach of its terms and conditions. + +7. Nothing in this License Agreement shall be deemed to create any +relationship of agency, partnership, or joint venture between PSF and +Licensee. This License Agreement does not grant permission to use PSF +trademarks or trade name in a trademark sense to endorse or promote +products or services of Licensee, or any third party. + +8. By copying, installing or otherwise using Python, Licensee +agrees to be bound by the terms and conditions of this License +Agreement. + + +10.) License Notice for msrest from https://raw.githubusercontent.com/Azure/msrest-for-python/master/LICENSE.md +----------------------------------------------------------------------------------------------------------------------- + +MIT License + +Copyright (c) 2016 Microsoft Azure + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/tls_protocol_version_and_ciphers.md b/tls_protocol_version_and_ciphers.md new file mode 100644 index 000000000..494aece33 --- /dev/null +++ b/tls_protocol_version_and_ciphers.md @@ -0,0 +1,11 @@ +# IoT Python SDK support for TLS 1.2 + +## TLS Version + +The Python SDK fully supports TLS 1.2 in all of its APIs. + +Due to security concerns the Python SDK does not allow TLS 1.1 connections. + +## TLS Cipher Suites + +Coming Soon diff --git a/vsts/build-release.yml b/vsts/build-release.yml index e00e82619..797df8afc 100644 --- a/vsts/build-release.yml +++ b/vsts/build-release.yml @@ -33,6 +33,7 @@ steps: sources: $(Build.SourcesDirectory) device_version_part: $(azure-iot-device-version-part) nspkg_version_part: $(azure-iot-nspkg-version-part) + hub_version_part: $(azure-iot-hub-version-part) displayName: 'build release artifacts' - task: UsePythonVersion@0 diff --git a/vsts/build.yaml b/vsts/build.yaml index 8b230324c..8d396ddc3 100644 --- a/vsts/build.yaml +++ b/vsts/build.yaml @@ -26,15 +26,14 @@ jobs: matrix: Python27: python.version: '2.7' - Python34: - python.version: '3.4' Python35: python.version: '3.5' Python36: python.version: '3.6' Python37: python.version: '3.7' - maxParallel: 5 + Python38: + python.version: '3.8' steps: - task: UsePythonVersion@0 displayName: 'Use Python $(python.version)' diff --git a/vsts/horton-e2e.yaml b/vsts/horton-e2e.yaml index 0b0b6d72f..6bebca286 100644 --- a/vsts/horton-e2e.yaml +++ b/vsts/horton-e2e.yaml @@ -1,7 +1,7 @@ variables: Horton.FrameworkRoot: $(Agent.BuildDirectory)/e2e-fx Horton.FrameworkRef: master - Horton.Language: pythonpreview + Horton.Language: pythonv2 Horton.Repo: $(Build.Repository.Uri) Horton.Commit: $(Build.SourceBranch) Horton.ForcedImage: '' @@ -15,6 +15,6 @@ resources: endpoint: 'GitHub OAuth - az-iot-builder-01' jobs: -- template: vsts/templates/jobs-gate-pythonpreview.yaml@e2e_fx +- template: vsts/templates/jobs-gate-pythonv2.yaml@e2e_fx