зеркало из
1
0
Форкнуть 0

Merge branch 'ryanwinter/support' of https://github.com/Azure/azure-iot-sdk-python into ryanwinter/support

This commit is contained in:
Ryan Winter 2023-05-04 13:31:10 -07:00
Родитель 11eb1b8f5b f0b6f74fbf
Коммит de608f170e
17 изменённых файлов: 353 добавлений и 503 удалений

1
.github/ISSUE_TEMPLATE/bug-report.md поставляемый
Просмотреть файл

@ -27,6 +27,7 @@ Please follow the instructions and template below to save us time requesting add
- A detailed description. - A detailed description.
- A [Minimal Complete Reproducible Example](https://stackoverflow.com/help/mcve). This is code we can cut and paste into a readily available sample and run, or a link to a project you've written that we can compile to reproduce the bug. - A [Minimal Complete Reproducible Example](https://stackoverflow.com/help/mcve). This is code we can cut and paste into a readily available sample and run, or a link to a project you've written that we can compile to reproduce the bug.
- Console logs. - Console logs.
- If this is a connection related issue, include logs from the [Connection Diagnostic Tool](https://github.com/Azure/azure-iot-connection-diagnostic-tool)
5. Delete these instructions before submitting the bug. 5. Delete these instructions before submitting the bug.

Просмотреть файл

@ -31,10 +31,11 @@ Python 3.7 or higher is required in order to use the library
## Using the library ## Using the library
You can view the [**samples repository**](https://github.com/Azure/azure-iot-sdk-python/tree/main/samples) to see examples of SDK usage. You can view the [**samples directory**](https://github.com/Azure/azure-iot-sdk-python/tree/main/samples) to see examples of SDK usage.
Full API documentation for this package is available via [**Microsoft Docs**](https://docs.microsoft.com/python/api/azure-iot-device/azure.iot.device?view=azure-python). Note that this documentation may currently be out of date as v3.x.x is still in preview at the time of this writing. Full API documentation for this package is available via [**Microsoft Docs**](https://docs.microsoft.com/python/api/azure-iot-device/azure.iot.device?view=azure-python). Note that this documentation may currently be out of date as v3.x.x is still in preview at the time of this writing.
You can use the [**Connection Diagnostic Tool**](https://github.com/Azure/azure-iot-connection-diagnostic-tool) to help ascertain the cause of any connection issues you run into when using the SDK.
## Features ## Features

Просмотреть файл

@ -286,7 +286,8 @@ class IoTHubSession:
) )
finally: finally:
try: try:
await self._mqtt_client.disable_c2d_message_receive() if self._mqtt_client.connected:
await self._mqtt_client.disable_c2d_message_receive()
except mqtt.MQTTError: except mqtt.MQTTError:
# i.e. not connected # i.e. not connected
# This error would be expected if a disconnection has ocurred # This error would be expected if a disconnection has ocurred
@ -304,7 +305,8 @@ class IoTHubSession:
) )
finally: finally:
try: try:
await self._mqtt_client.disable_direct_method_request_receive() if self._mqtt_client.connected:
await self._mqtt_client.disable_direct_method_request_receive()
except mqtt.MQTTError: except mqtt.MQTTError:
# i.e. not connected # i.e. not connected
# This error would be expected if a disconnection has ocurred # This error would be expected if a disconnection has ocurred
@ -322,7 +324,8 @@ class IoTHubSession:
) )
finally: finally:
try: try:
await self._mqtt_client.disable_twin_patch_receive() if self._mqtt_client.connected:
await self._mqtt_client.disable_twin_patch_receive()
except mqtt.MQTTError: except mqtt.MQTTError:
# i.e. not connected # i.e. not connected
# This error would be expected if a disconnection has ocurred # This error would be expected if a disconnection has ocurred

Просмотреть файл

@ -106,7 +106,6 @@ async def main():
from azure.iot.device import IoTHubSession from azure.iot.device import IoTHubSession
async def main(): async def main():
async def main():
async with IoTHubSession.from_connection_string("<Your Connection String>") as session: async with IoTHubSession.from_connection_string("<Your Connection String>") as session:
async with session.messages() as messages: async with session.messages() as messages:
async for message in messages: async for message in messages:
@ -236,7 +235,7 @@ client = IoTHubDeviceClient.create_from_x509_certificate(
from azure.iot.device import IoTHubSession from azure.iot.device import IoTHubSession
import ssl import ssl
ssl_context = ssl.SSLContext.create_default_context() ssl_context = ssl.create_default_context()
ssl_context.load_cert_chain( ssl_context.load_cert_chain(
certfile="<Your X509 Cert File Path>", certfile="<Your X509 Cert File Path>",
keyfile="<Your X509 Key File>", keyfile="<Your X509 Key File>",
@ -257,7 +256,7 @@ Note that SSLContexts can be used with the `.from_connection_string()` factory
from azure.iot.device import IoTHubSession from azure.iot.device import IoTHubSession
import ssl import ssl
ssl_context = ssl.SSLContext.create_default_context() ssl_context = ssl.create_default_context()
ssl_context.load_cert_chain( ssl_context.load_cert_chain(
certfile="<Your X509 Cert File Path>", certfile="<Your X509 Cert File Path>",
keyfile="<Your X509 Key File>", keyfile="<Your X509 Key File>",
@ -294,7 +293,7 @@ client = IoTHubDeviceClient.create_from_connection_string(
from azure.iot.device import IoTHubSession from azure.iot.device import IoTHubSession
import ssl import ssl
ssl_context = ssl.SSLContext.create_default_context( ssl_context = ssl.create_default_context(
cafile="<Your CA Certificate File Path>", cafile="<Your CA Certificate File Path>",
) )
@ -320,7 +319,7 @@ client = IoTHubDeviceClient.create_from_connection_string(
from azure.iot.device import IoTHubSession from azure.iot.device import IoTHubSession
import ssl import ssl
ssl_context = ssl.SSLContext.create_default_context() ssl_context = ssl.create_default_context()
ssl_context.set_ciphers("<Your Cipher>") ssl_context.set_ciphers("<Your Cipher>")
client = IoTHubSession.from_connection_string( client = IoTHubSession.from_connection_string(

Просмотреть файл

@ -77,7 +77,7 @@ from azure.iot.device import ProvisioningSession
import ssl import ssl
async def main(): async def main():
ssl_context = ssl.SSLContext.create_default_context() ssl_context = ssl.create_default_context()
ssl_context.load_cert_chain( ssl_context.load_cert_chain(
certfile="<Your X509 Cert File Path>", certfile="<Your X509 Cert File Path>",
keyfile="<Your X509 Key File>", keyfile="<Your X509 Key File>",
@ -121,7 +121,7 @@ provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key(
from azure.iot.device import ProvisioningSession from azure.iot.device import ProvisioningSession
import ssl import ssl
ssl_context = ssl.SSLContext.create_default_context( ssl_context = ssl.create_default_context(
cafile="<Your CA Certificate File Path>", cafile="<Your CA Certificate File Path>",
) )
@ -153,7 +153,7 @@ provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key(
from azure.iot.device import ProvisioningSession from azure.iot.device import ProvisioningSession
import ssl import ssl
ssl_context = ssl.SSLContext.create_default_context() ssl_context = ssl.create_default_context()
ssl_context.set_ciphers("<Your Cipher>") ssl_context.set_ciphers("<Your Cipher>")
session = ProvisioningSession( session = ProvisioningSession(

Просмотреть файл

@ -3,7 +3,12 @@
# Licensed under the MIT License. See License.txt in the project root for # Licensed under the MIT License. See License.txt in the project root for
# license information. # license information.
# -------------------------------------------------------------------------- # --------------------------------------------------------------------------
"""This sample demonstrates a simple recurring telemetry using an IoTHubSession""" """
This sample demonstrates a simple recurring telemetry using an IoTHubSession
It's set to be used in the following MS Learn Tutorial:
https://learn.microsoft.com/en-us/azure/iot-develop/quickstart-send-telemetry-iot-hub?pivots=programming-language-python
"""
import asyncio import asyncio
import os import os

Просмотреть файл

@ -36,6 +36,8 @@ class TestMethods(object):
service_helper, service_helper,
leak_tracker, leak_tracker,
): ):
done_sending_response = asyncio.Event()
leak_tracker.set_initial_object_list() leak_tracker.set_initial_object_list()
registered = asyncio.Event() registered = asyncio.Event()
@ -53,7 +55,7 @@ class TestMethods(object):
async def method_listener(sess): async def method_listener(sess):
try: try:
nonlocal actual_request nonlocal actual_request, done_sending_response
async with sess.direct_method_requests() as requests: async with sess.direct_method_requests() as requests:
registered.set() registered.set()
async for request in requests: async for request in requests:
@ -65,6 +67,8 @@ class TestMethods(object):
request, method_response_status, response_payload request, method_response_status, response_payload
) )
) )
done_sending_response.set()
except asyncio.CancelledError: except asyncio.CancelledError:
# this happens during shutdown. no need to log this. # this happens during shutdown. no need to log this.
raise raise
@ -83,6 +87,13 @@ class TestMethods(object):
logger.info("Invoking method") logger.info("Invoking method")
method_response = await service_helper.invoke_method(method_name, request_payload) method_response = await service_helper.invoke_method(method_name, request_payload)
logger.info("Done Invoking method") logger.info("Done Invoking method")
# This is counterintuitive, Even though we've received the method response,
# we don't know if the client is done sending the response. This is because
# iothub returns the method repsonse immediately. It's possible that the
# PUBACK hasn't been received by the device client yet. We need to wait until
# the client receives the PUBACK before we exit.
await done_sending_response.wait()
logger.info("signal from listener received. Exiting session.")
assert session.connected is False assert session.connected is False
with pytest.raises(asyncio.CancelledError): with pytest.raises(asyncio.CancelledError):

Просмотреть файл

@ -1,22 +1,91 @@
# # ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# # Copyright (c) Microsoft Corporation. All rights reserved. # Copyright (c) Microsoft Corporation. All rights reserved.
# # Licensed under the MIT License. See License.txt in the project root for # Licensed under the MIT License. See License.txt in the project root for
# # license information. # license information.
# # -------------------------------------------------------------------------- # --------------------------------------------------------------------------
# from azure.iot.device.common.connection_string import ConnectionString """This module contains tools for working with Connection Strings"""
# from azure.iot.device.common.sastoken import SasToken
# __all__ = ["ConnectionString"]
#
# def connection_string_to_sas_token(conn_str): CS_DELIMITER = ";"
# """ CS_VAL_SEPARATOR = "="
# parse an IoTHub service connection string and return the host and a shared access
# signature that can be used to connect to the given hub HOST_NAME = "HostName"
# """ SHARED_ACCESS_KEY_NAME = "SharedAccessKeyName"
# conn_str_obj = ConnectionString(conn_str) SHARED_ACCESS_KEY = "SharedAccessKey"
# sas_token = SasToken( SHARED_ACCESS_SIGNATURE = "SharedAccessSignature"
# uri=conn_str_obj.get("HostName"), DEVICE_ID = "DeviceId"
# key=conn_str_obj.get("SharedAccessKey"), MODULE_ID = "ModuleId"
# key_name=conn_str_obj.get("SharedAccessKeyName"), GATEWAY_HOST_NAME = "GatewayHostName"
# )
# _valid_keys = [
# return {"host": conn_str_obj.get("HostName"), "sas": str(sas_token)} HOST_NAME,
SHARED_ACCESS_KEY_NAME,
SHARED_ACCESS_KEY,
SHARED_ACCESS_SIGNATURE,
DEVICE_ID,
MODULE_ID,
GATEWAY_HOST_NAME,
]
def _parse_connection_string(connection_string):
"""Return a dictionary of values contained in a given connection string"""
cs_args = connection_string.split(CS_DELIMITER)
d = dict(arg.split(CS_VAL_SEPARATOR, 1) for arg in cs_args)
if len(cs_args) != len(d):
# various errors related to incorrect parsing - duplicate args, bad syntax, etc.
raise ValueError("Invalid Connection String - Unable to parse")
if not all(key in _valid_keys for key in d.keys()):
raise ValueError("Invalid Connection String - Invalid Key")
_validate_keys(d)
return d
def _validate_keys(d):
"""Raise ValueError if incorrect combination of keys in dict d"""
host_name = d.get(HOST_NAME)
shared_access_key_name = d.get(SHARED_ACCESS_KEY_NAME)
shared_access_key = d.get(SHARED_ACCESS_KEY)
device_id = d.get(DEVICE_ID)
# This logic could be expanded to return the category of ConnectionString
if host_name and device_id and shared_access_key:
pass
elif host_name and shared_access_key and shared_access_key_name:
pass
else:
raise ValueError("Invalid Connection String - Incomplete")
class ConnectionString(object):
"""Key/value mappings for connection details.
Uses the same syntax as dictionary
"""
def __init__(self, connection_string):
"""Initializer for ConnectionString
:param str connection_string: String with connection details provided by Azure
:raises: ValueError if provided connection_string is invalid
"""
self._dict = _parse_connection_string(connection_string)
self._strrep = connection_string
def __getitem__(self, key):
return self._dict[key]
def __repr__(self):
return self._strrep
def get(self, key, default=None):
"""Return the value for key if key is in the dictionary, else default
:param str key: The key to retrieve a value for
:param str default: The default value returned if a key is not found
:returns: The value for the given key
"""
try:
return self._dict[key]
except KeyError:
return default

Просмотреть файл

@ -1,3 +1,3 @@
[pytest] [pytest]
addopts = --timeout 30 addopts = --timeout 90
asyncio_mode=auto asyncio_mode=auto

Просмотреть файл

@ -0,0 +1,81 @@
"""This module contains tools for working with Shared Access Signature (SAS) Tokens"""
import base64
import hmac
import hashlib
import time
import urllib.parse
class SasTokenError(Exception):
"""Error in SasToken"""
def __init__(self, message, cause=None):
"""Initializer for SasTokenError
:param str message: Error message
:param cause: Exception that caused this error (optional)
"""
super(SasTokenError, self).__init__(message)
self.cause = cause
class SasToken(object):
"""Shared Access Signature Token used to authenticate a request
Parameters:
uri (str): URI of the resouce to be accessed
key_name (str): Shared Access Key Name
key (str): Shared Access Key (base64 encoded)
ttl (int)[default 3600]: Time to live for the token, in seconds
Data Attributes:
expiry_time (int): Time that token will expire (in UTC, since epoch)
ttl (int): Time to live for the token, in seconds
Raises:
SasTokenError if trying to build a SasToken from invalid values
"""
_encoding_type = "utf-8"
_service_token_format = "SharedAccessSignature sr={}&sig={}&se={}&skn={}"
_device_token_format = "SharedAccessSignature sr={}&sig={}&se={}"
def __init__(self, uri, key, key_name=None, ttl=3600):
self._uri = urllib.parse.quote(uri, safe="")
self._key = key
self._key_name = key_name
self.ttl = ttl
self.refresh()
def __str__(self):
return self._token
def refresh(self):
"""
Refresh the SasToken lifespan, giving it a new expiry time
"""
self.expiry_time = int(time.time() + self.ttl)
self._token = self._build_token()
def _build_token(self):
"""Buid SasToken representation
Returns:
String representation of the token
"""
try:
message = (self._uri + "\n" + str(self.expiry_time)).encode(self._encoding_type)
signing_key = base64.b64decode(self._key.encode(self._encoding_type))
signed_hmac = hmac.HMAC(signing_key, message, hashlib.sha256)
signature = urllib.parse.quote(base64.b64encode(signed_hmac.digest()))
except (TypeError, base64.binascii.Error) as e:
raise SasTokenError("Unable to build SasToken from given values", e)
if self._key_name:
token = self._service_token_format.format(
self._uri, signature, str(self.expiry_time), self._key_name
)
else:
token = self._device_token_format.format(self._uri, signature, str(self.expiry_time))
return token

Просмотреть файл

@ -3,15 +3,17 @@
# Licensed under the MIT License. See License.txt in the project root for # Licensed under the MIT License. See License.txt in the project root for
# license information. # license information.
# -------------------------------------------------------------------------- # --------------------------------------------------------------------------
from typing import Optional
from provisioning_e2e.iothubservice20180630.iot_hub_gateway_service_ap_is20180630 import ( from .iothubservice20180630.iot_hub_gateway_service_ap_is20180630 import (
IotHubGatewayServiceAPIs20180630, IotHubGatewayServiceAPIs20180630,
) )
from msrest.exceptions import HttpOperationError from msrest.exceptions import HttpOperationError
from azure.iot.device.common.auth.connection_string import ConnectionString
from azure.iot.device.common.auth.sastoken import RenewableSasToken from .connection_string import ConnectionString
from azure.iot.device.common.auth.signing_mechanism import SymmetricKeySigningMechanism from .sastoken import SasToken
import uuid import uuid
import time import time
import random import random
@ -27,10 +29,9 @@ def connection_string_to_sas_token(conn_str):
signature that can be used to connect to the given hub signature that can be used to connect to the given hub
""" """
conn_str_obj = ConnectionString(conn_str) conn_str_obj = ConnectionString(conn_str)
signing_mechanism = SymmetricKeySigningMechanism(conn_str_obj.get("SharedAccessKey")) sas_token = SasToken(
sas_token = RenewableSasToken(
uri=conn_str_obj.get("HostName"), uri=conn_str_obj.get("HostName"),
signing_mechanism=signing_mechanism, key=conn_str_obj.get("SharedAccessKey"),
key_name=conn_str_obj.get("SharedAccessKeyName"), key_name=conn_str_obj.get("SharedAccessKeyName"),
) )
@ -46,6 +47,16 @@ def connection_string_to_hostname(conn_str):
return conn_str_obj.get("HostName") return conn_str_obj.get("HostName")
def _format_sas_uri(hostname: str, device_id: str, module_id: Optional[str]) -> str:
"""Format the SAS URI for using IoT Hub"""
if module_id:
return "{hostname}/devices/{device_id}/modules/{module_id}".format(
hostname=hostname, device_id=device_id, module_id=module_id
)
else:
return "{hostname}/devices/{device_id}".format(hostname=hostname, device_id=device_id)
def run_with_retry(fun, args, kwargs): def run_with_retry(fun, args, kwargs):
failures_left = max_failure_count failures_left = max_failure_count
retry = True retry = True
@ -70,7 +81,7 @@ def run_with_retry(fun, args, kwargs):
raise e raise e
class Helper: class ServiceRegistryHelper:
def __init__(self, service_connection_string): def __init__(self, service_connection_string):
self.cn = connection_string_to_sas_token(service_connection_string) self.cn = connection_string_to_sas_token(service_connection_string)
self.service = IotHubGatewayServiceAPIs20180630("https://" + self.cn["host"]).service self.service = IotHubGatewayServiceAPIs20180630("https://" + self.cn["host"]).service

Просмотреть файл

@ -13,6 +13,7 @@ from os.path import dirname as dir
# no longer true, we can get rid of this file. # no longer true, we can get rid of this file.
root_path = dir(dir(sys.path[0])) root_path = dir(dir(sys.path[0]))
script_path = os.path.join(root_path, "scripts") script_path = os.path.join(root_path, "scripts")
print("The path after scripts is")
print(script_path) print(script_path)
if script_path not in sys.path: if script_path not in sys.path:
sys.path.append(script_path) sys.path.append(script_path)

Просмотреть файл

@ -5,9 +5,9 @@
# -------------------------------------------------------------------------- # --------------------------------------------------------------------------
from provisioning_e2e.service_helper import Helper, connection_string_to_hostname from ..service_helper import ServiceRegistryHelper, connection_string_to_hostname
from azure.iot.device.aio import ProvisioningDeviceClient from azure.iot.device import ProvisioningSession
from azure.iot.device.common import X509
from provisioningserviceclient import ( from provisioningserviceclient import (
ProvisioningServiceClient, ProvisioningServiceClient,
IndividualEnrollment, IndividualEnrollment,
@ -18,11 +18,11 @@ import pytest
import logging import logging
import os import os
import uuid import uuid
import ssl
from . import path_adjust # noqa: F401 from . import path_adjust # noqa: F401
# Refers to an item in "scripts" in the root. This is made to work via the above path_adjust # Refers to an item in "scripts" in the root. This is made to work via the above path_adjust
from create_x509_chain_crypto import ( from scripts.create_x509_chain_crypto import (
before_cert_creation_from_pipeline, before_cert_creation_from_pipeline,
call_intermediate_cert_and_device_cert_creation_from_pipeline, call_intermediate_cert_and_device_cert_creation_from_pipeline,
delete_directories_certs_created_from_pipeline, delete_directories_certs_created_from_pipeline,
@ -40,7 +40,7 @@ device_password = "mortis"
service_client = ProvisioningServiceClient.create_from_connection_string( service_client = ProvisioningServiceClient.create_from_connection_string(
os.getenv("PROVISIONING_SERVICE_CONNECTION_STRING") os.getenv("PROVISIONING_SERVICE_CONNECTION_STRING")
) )
device_registry_helper = Helper(os.getenv("IOTHUB_CONNECTION_STRING")) device_registry_helper = ServiceRegistryHelper(os.getenv("IOTHUB_CONNECTION_STRING"))
linked_iot_hub = connection_string_to_hostname(os.getenv("IOTHUB_CONNECTION_STRING")) linked_iot_hub = connection_string_to_hostname(os.getenv("IOTHUB_CONNECTION_STRING"))
PROVISIONING_HOST = os.getenv("PROVISIONING_DEVICE_ENDPOINT") PROVISIONING_HOST = os.getenv("PROVISIONING_DEVICE_ENDPOINT")
@ -95,6 +95,7 @@ async def test_device_register_with_device_id_for_a_x509_individual_enrollment(p
registration_id, device_cert_file, device_key_file, protocol registration_id, device_cert_file, device_key_file, protocol
) )
assert registration_result is not None
assert device_id != registration_id assert device_id != registration_id
assert_device_provisioned(device_id=device_id, registration_result=registration_result) assert_device_provisioned(device_id=device_id, registration_result=registration_result)
device_registry_helper.try_delete_device(device_id) device_registry_helper.try_delete_device(device_id)
@ -121,6 +122,7 @@ async def test_device_register_with_no_device_id_for_a_x509_individual_enrollmen
registration_id, device_cert_file, device_key_file, protocol registration_id, device_cert_file, device_key_file, protocol
) )
assert registration_result is not None
assert_device_provisioned( assert_device_provisioned(
device_id=registration_id, registration_result=registration_result device_id=registration_id, registration_result=registration_result
) )
@ -179,6 +181,7 @@ async def test_group_of_devices_register_with_no_device_id_for_a_x509_intermedia
protocol=protocol, protocol=protocol,
) )
assert registration_result is not None
assert_device_provisioned(device_id=device_id, registration_result=registration_result) assert_device_provisioned(device_id=device_id, registration_result=registration_result)
device_registry_helper.try_delete_device(device_id) device_registry_helper.try_delete_device(device_id)
@ -242,7 +245,7 @@ async def test_group_of_devices_register_with_no_device_id_for_a_x509_ca_authent
device_key_file=device_key_input_file, device_key_file=device_key_input_file,
protocol=protocol, protocol=protocol,
) )
assert registration_result is not None
assert_device_provisioned(device_id=device_id, registration_result=registration_result) assert_device_provisioned(device_id=device_id, registration_result=registration_result)
device_registry_helper.try_delete_device(device_id) device_registry_helper.try_delete_device(device_id)
@ -258,9 +261,9 @@ def assert_device_provisioned(device_id, registration_result):
:param device_id: The device id :param device_id: The device id
:param registration_result: The registration result :param registration_result: The registration result
""" """
assert registration_result.status == "assigned" assert registration_result["status"] == "assigned"
assert registration_result.registration_state.device_id == device_id assert registration_result["registrationState"]["deviceId"] == device_id
assert registration_result.registration_state.assigned_hub == linked_iot_hub assert registration_result["registrationState"]["assignedHub"] == linked_iot_hub
device = device_registry_helper.get_device(device_id) device = device_registry_helper.get_device(device_id)
assert device is not None assert device is not None
@ -289,16 +292,31 @@ def create_individual_enrollment_with_x509_client_certs(device_index, device_id=
async def result_from_register(registration_id, device_cert_file, device_key_file, protocol): async def result_from_register(registration_id, device_cert_file, device_key_file, protocol):
x509 = X509(cert_file=device_cert_file, key_file=device_key_file, pass_phrase=device_password) # We have this mapping because the pytest logs look better with "mqtt" and "mqttws"
# instead of just "True" and "False".
protocol_boolean_mapping = {"mqtt": False, "mqttws": True} protocol_boolean_mapping = {"mqtt": False, "mqttws": True}
provisioning_device_client = ProvisioningDeviceClient.create_from_x509_certificate( ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
ssl_context.verify_mode = ssl.CERT_REQUIRED
ssl_context.check_hostname = True
ssl_context.load_default_certs()
ssl_context.load_cert_chain(
certfile=device_cert_file,
keyfile=device_key_file,
password=device_password,
)
async with ProvisioningSession(
provisioning_host=PROVISIONING_HOST, provisioning_host=PROVISIONING_HOST,
registration_id=registration_id, registration_id=registration_id,
id_scope=ID_SCOPE, id_scope=ID_SCOPE,
x509=x509, ssl_context=ssl_context,
websockets=protocol_boolean_mapping[protocol], websockets=protocol_boolean_mapping[protocol],
) ) as session:
print("Connected")
result = await provisioning_device_client.register() properties = {"Type": "Apple", "Sweet": True, "count": 5}
await provisioning_device_client.shutdown() result = await session.register(payload=properties)
return result print("Finished provisioning")
print(result)
result = await session.register()
return result if result is not None else None

Просмотреть файл

@ -4,8 +4,7 @@
# license information. # license information.
# -------------------------------------------------------------------------- # --------------------------------------------------------------------------
from provisioning_e2e.service_helper import Helper, connection_string_to_hostname from ..service_helper import ServiceRegistryHelper, connection_string_to_hostname
from azure.iot.device.aio import ProvisioningDeviceClient
from provisioningserviceclient import ProvisioningServiceClient, IndividualEnrollment from provisioningserviceclient import ProvisioningServiceClient, IndividualEnrollment
from provisioningserviceclient.protocol.models import AttestationMechanism, ReprovisionPolicy from provisioningserviceclient.protocol.models import AttestationMechanism, ReprovisionPolicy
import pytest import pytest
@ -13,6 +12,8 @@ import logging
import os import os
import uuid import uuid
from azure.iot.device import ProvisioningSession
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
@ -23,7 +24,7 @@ service_client = ProvisioningServiceClient.create_from_connection_string(
os.getenv("PROVISIONING_SERVICE_CONNECTION_STRING") os.getenv("PROVISIONING_SERVICE_CONNECTION_STRING")
) )
service_client = ProvisioningServiceClient.create_from_connection_string(conn_str) service_client = ProvisioningServiceClient.create_from_connection_string(conn_str)
device_registry_helper = Helper(os.getenv("IOTHUB_CONNECTION_STRING")) device_registry_helper = ServiceRegistryHelper(os.getenv("IOTHUB_CONNECTION_STRING"))
linked_iot_hub = connection_string_to_hostname(os.getenv("IOTHUB_CONNECTION_STRING")) linked_iot_hub = connection_string_to_hostname(os.getenv("IOTHUB_CONNECTION_STRING"))
@ -45,6 +46,7 @@ async def test_device_register_with_no_device_id_for_a_symmetric_key_individual_
registration_result = await result_from_register(registration_id, symmetric_key, protocol) registration_result = await result_from_register(registration_id, symmetric_key, protocol)
assert registration_result is not None
assert_device_provisioned( assert_device_provisioned(
device_id=registration_id, registration_result=registration_result device_id=registration_id, registration_result=registration_result
) )
@ -70,6 +72,7 @@ async def test_device_register_with_device_id_for_a_symmetric_key_individual_enr
registration_result = await result_from_register(registration_id, symmetric_key, protocol) registration_result = await result_from_register(registration_id, symmetric_key, protocol)
assert registration_result is not None
assert device_id != registration_id assert device_id != registration_id
assert_device_provisioned(device_id=device_id, registration_result=registration_result) assert_device_provisioned(device_id=device_id, registration_result=registration_result)
device_registry_helper.try_delete_device(device_id) device_registry_helper.try_delete_device(device_id)
@ -103,9 +106,10 @@ def assert_device_provisioned(device_id, registration_result):
:param device_id: The device id :param device_id: The device id
:param registration_result: The registration result :param registration_result: The registration result
""" """
assert registration_result.status == "assigned" print(registration_result)
assert registration_result.registration_state.device_id == device_id assert registration_result["status"] == "assigned"
assert registration_result.registration_state.assigned_hub == linked_iot_hub assert registration_result["registrationState"]["deviceId"] == device_id
assert registration_result["registrationState"]["assignedHub"] == linked_iot_hub
device = device_registry_helper.get_device(device_id) device = device_registry_helper.get_device(device_id)
assert device is not None assert device is not None
@ -113,19 +117,16 @@ def assert_device_provisioned(device_id, registration_result):
assert device.device_id == device_id assert device.device_id == device_id
# TODO Eventually should return result after the APi changes
async def result_from_register(registration_id, symmetric_key, protocol): async def result_from_register(registration_id, symmetric_key, protocol):
# We have this mapping because the pytest logs look better with "mqtt" and "mqttws" # We have this mapping because the pytest logs look better with "mqtt" and "mqttws"
# instead of just "True" and "False". # instead of just "True" and "False".
protocol_boolean_mapping = {"mqtt": False, "mqttws": True} protocol_boolean_mapping = {"mqtt": False, "mqttws": True}
provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key( async with ProvisioningSession(
provisioning_host=PROVISIONING_HOST, provisioning_host=PROVISIONING_HOST,
registration_id=registration_id, registration_id=registration_id,
id_scope=ID_SCOPE, id_scope=ID_SCOPE,
symmetric_key=symmetric_key, shared_access_key=symmetric_key,
websockets=protocol_boolean_mapping[protocol], websockets=protocol_boolean_mapping[protocol],
) ) as session:
result = await session.register()
result = await provisioning_device_client.register() return result if result is not None else None
await provisioning_device_client.shutdown()
return result

Просмотреть файл

@ -1,304 +0,0 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from provisioning_e2e.service_helper import Helper, connection_string_to_hostname
from azure.iot.device import ProvisioningDeviceClient
from azure.iot.device.common import X509
from provisioningserviceclient import (
ProvisioningServiceClient,
IndividualEnrollment,
EnrollmentGroup,
)
from provisioningserviceclient.protocol.models import AttestationMechanism, ReprovisionPolicy
import pytest
import logging
import os
import uuid
from . import path_adjust # noqa: F401
# Refers to an item in "scripts" in the root. This is made to work via the above path_adjust
from create_x509_chain_crypto import (
before_cert_creation_from_pipeline,
call_intermediate_cert_and_device_cert_creation_from_pipeline,
delete_directories_certs_created_from_pipeline,
)
logging.basicConfig(level=logging.DEBUG)
intermediate_common_name = "e2edpswingardium"
intermediate_password = "leviosa"
device_common_name = "e2edpsexpecto" + str(uuid.uuid4())
device_password = "patronum"
service_client = ProvisioningServiceClient.create_from_connection_string(
os.getenv("PROVISIONING_SERVICE_CONNECTION_STRING")
)
device_registry_helper = Helper(os.getenv("IOTHUB_CONNECTION_STRING"))
linked_iot_hub = connection_string_to_hostname(os.getenv("IOTHUB_CONNECTION_STRING"))
PROVISIONING_HOST = os.getenv("PROVISIONING_DEVICE_ENDPOINT")
ID_SCOPE = os.getenv("PROVISIONING_DEVICE_IDSCOPE")
certificate_count = 8
type_to_device_indices = {
"individual_with_device_id": [1],
"individual_no_device_id": [2],
"group_intermediate": [3, 4, 5],
"group_ca": [6, 7, 8],
}
@pytest.fixture(scope="module", autouse=True)
def before_all_tests(request):
logging.info("set up certificates before cert related tests")
before_cert_creation_from_pipeline()
call_intermediate_cert_and_device_cert_creation_from_pipeline(
intermediate_common_name=intermediate_common_name,
device_common_name=device_common_name,
ca_password=os.getenv("PROVISIONING_ROOT_PASSWORD"),
intermediate_password=intermediate_password,
device_password=device_password,
device_count=8,
)
def after_module():
logging.info("tear down certificates after cert related tests")
delete_directories_certs_created_from_pipeline()
request.addfinalizer(after_module)
@pytest.mark.it(
"A device gets provisioned to the linked IoTHub with the user supplied device_id different from the registration_id of the individual enrollment that has been created with a selfsigned X509 authentication"
)
@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"])
def test_device_register_with_device_id_for_a_x509_individual_enrollment(protocol):
device_id = "e2edpsflyingfeather"
device_index = type_to_device_indices.get("individual_with_device_id")[0]
try:
individual_enrollment_record = create_individual_enrollment_with_x509_client_certs(
device_index=device_index, device_id=device_id
)
registration_id = individual_enrollment_record.registration_id
device_cert_file = "demoCA/newcerts/device_cert" + str(device_index) + ".pem"
device_key_file = "demoCA/private/device_key" + str(device_index) + ".pem"
registration_result = result_from_register(
registration_id, device_cert_file, device_key_file, protocol
)
assert device_id != registration_id
assert_device_provisioned(device_id=device_id, registration_result=registration_result)
device_registry_helper.try_delete_device(device_id)
finally:
service_client.delete_individual_enrollment_by_param(registration_id)
@pytest.mark.it(
"A device gets provisioned to the linked IoTHub with device_id equal to the registration_id of the individual enrollment that has been created with a selfsigned X509 authentication"
)
@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"])
def test_device_register_with_no_device_id_for_a_x509_individual_enrollment(protocol):
device_index = type_to_device_indices.get("individual_no_device_id")[0]
try:
individual_enrollment_record = create_individual_enrollment_with_x509_client_certs(
device_index=device_index
)
registration_id = individual_enrollment_record.registration_id
device_cert_file = "demoCA/newcerts/device_cert" + str(device_index) + ".pem"
device_key_file = "demoCA/private/device_key" + str(device_index) + ".pem"
registration_result = result_from_register(
registration_id, device_cert_file, device_key_file, protocol
)
assert_device_provisioned(
device_id=registration_id, registration_result=registration_result
)
device_registry_helper.try_delete_device(registration_id)
finally:
service_client.delete_individual_enrollment_by_param(registration_id)
@pytest.mark.it(
"A group of devices get provisioned to the linked IoTHub with device_ids equal to the individual registration_ids inside a group enrollment that has been created with intermediate X509 authentication"
)
@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"])
def test_group_of_devices_register_with_no_device_id_for_a_x509_intermediate_authentication_group_enrollment(
protocol,
):
group_id = "e2e-intermediate-hogwarts" + str(uuid.uuid4())
common_device_id = device_common_name
devices_indices = type_to_device_indices.get("group_intermediate")
device_count_in_group = len(devices_indices)
reprovision_policy = ReprovisionPolicy(migrate_device_data=True)
try:
intermediate_cert_filename = "demoCA/newcerts/intermediate_cert.pem"
with open(intermediate_cert_filename, "r") as intermediate_pem:
intermediate_cert_content = intermediate_pem.read()
attestation_mechanism = AttestationMechanism.create_with_x509_signing_certs(
intermediate_cert_content
)
enrollment_group_provisioning_model = EnrollmentGroup.create(
group_id, attestation=attestation_mechanism, reprovision_policy=reprovision_policy
)
service_client.create_or_update(enrollment_group_provisioning_model)
count = 0
common_device_key_input_file = "demoCA/private/device_key"
common_device_cert_input_file = "demoCA/newcerts/device_cert"
common_device_inter_cert_chain_file = "demoCA/newcerts/out_inter_device_chain_cert"
for index in devices_indices:
count = count + 1
device_id = common_device_id + str(index)
device_key_input_file = common_device_key_input_file + str(index) + ".pem"
device_cert_input_file = common_device_cert_input_file + str(index) + ".pem"
device_inter_cert_chain_file = common_device_inter_cert_chain_file + str(index) + ".pem"
filenames = [device_cert_input_file, intermediate_cert_filename]
with open(device_inter_cert_chain_file, "w") as outfile:
for fname in filenames:
with open(fname) as infile:
outfile.write(infile.read())
registration_result = result_from_register(
registration_id=device_id,
device_cert_file=device_inter_cert_chain_file,
device_key_file=device_key_input_file,
protocol=protocol,
)
assert_device_provisioned(device_id=device_id, registration_result=registration_result)
device_registry_helper.try_delete_device(device_id)
# Make sure space is okay. The following line must be outside for loop.
assert count == device_count_in_group
finally:
service_client.delete_enrollment_group_by_param(group_id)
@pytest.mark.skip(
reason="The enrollment is never properly created on the pipeline and it is always created without any CA reference and eventually the registration fails"
)
@pytest.mark.it(
"A group of devices get provisioned to the linked IoTHub with device_ids equal to the individual registration_ids inside a group enrollment that has been created with an already uploaded ca cert X509 authentication"
)
@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"])
def test_group_of_devices_register_with_no_device_id_for_a_x509_ca_authentication_group_enrollment(
protocol,
):
group_id = "e2e-ca-beauxbatons" + str(uuid.uuid4())
common_device_id = device_common_name
devices_indices = type_to_device_indices.get("group_ca")
device_count_in_group = len(devices_indices)
reprovision_policy = ReprovisionPolicy(migrate_device_data=True)
try:
DPS_GROUP_CA_CERT = os.getenv("PROVISIONING_ROOT_CERT")
attestation_mechanism = AttestationMechanism.create_with_x509_ca_refs(
ref1=DPS_GROUP_CA_CERT
)
enrollment_group_provisioning_model = EnrollmentGroup.create(
group_id, attestation=attestation_mechanism, reprovision_policy=reprovision_policy
)
service_client.create_or_update(enrollment_group_provisioning_model)
count = 0
intermediate_cert_filename = "demoCA/newcerts/intermediate_cert.pem"
common_device_key_input_file = "demoCA/private/device_key"
common_device_cert_input_file = "demoCA/newcerts/device_cert"
common_device_inter_cert_chain_file = "demoCA/newcerts/out_inter_device_chain_cert"
for index in devices_indices:
count = count + 1
device_id = common_device_id + str(index)
device_key_input_file = common_device_key_input_file + str(index) + ".pem"
device_cert_input_file = common_device_cert_input_file + str(index) + ".pem"
device_inter_cert_chain_file = common_device_inter_cert_chain_file + str(index) + ".pem"
filenames = [device_cert_input_file, intermediate_cert_filename]
with open(device_inter_cert_chain_file, "w") as outfile:
for fname in filenames:
with open(fname) as infile:
logging.debug("Filename is {}".format(fname))
content = infile.read()
logging.debug(content)
outfile.write(content)
registration_result = result_from_register(
registration_id=device_id,
device_cert_file=device_inter_cert_chain_file,
device_key_file=device_key_input_file,
protocol=protocol,
)
assert_device_provisioned(device_id=device_id, registration_result=registration_result)
device_registry_helper.try_delete_device(device_id)
# Make sure space is okay. The following line must be outside for loop.
assert count == device_count_in_group
finally:
service_client.delete_enrollment_group_by_param(group_id)
def assert_device_provisioned(device_id, registration_result):
"""
Assert that the device has been provisioned correctly to iothub from the registration result as well as from the device registry
:param device_id: The device id
:param registration_result: The registration result
"""
assert registration_result.status == "assigned"
assert registration_result.registration_state.device_id == device_id
assert registration_result.registration_state.assigned_hub == linked_iot_hub
device = device_registry_helper.get_device(device_id)
assert device is not None
assert device.authentication.type == "selfSigned"
assert device.device_id == device_id
def create_individual_enrollment_with_x509_client_certs(device_index, device_id=None):
registration_id = device_common_name + str(device_index)
reprovision_policy = ReprovisionPolicy(migrate_device_data=True)
device_cert_input_file = "demoCA/newcerts/device_cert" + str(device_index) + ".pem"
with open(device_cert_input_file, "r") as in_device_cert:
device_cert_content = in_device_cert.read()
attestation_mechanism = AttestationMechanism.create_with_x509_client_certs(device_cert_content)
individual_provisioning_model = IndividualEnrollment.create(
attestation=attestation_mechanism,
registration_id=registration_id,
reprovision_policy=reprovision_policy,
device_id=device_id,
)
return service_client.create_or_update(individual_provisioning_model)
def result_from_register(registration_id, device_cert_file, device_key_file, protocol):
x509 = X509(cert_file=device_cert_file, key_file=device_key_file, pass_phrase=device_password)
protocol_boolean_mapping = {"mqtt": False, "mqttws": True}
provisioning_device_client = ProvisioningDeviceClient.create_from_x509_certificate(
provisioning_host=PROVISIONING_HOST,
registration_id=registration_id,
id_scope=ID_SCOPE,
x509=x509,
websockets=protocol_boolean_mapping[protocol],
)
result = provisioning_device_client.register()
provisioning_device_client.shutdown()
return result

Просмотреть файл

@ -1,122 +0,0 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from provisioning_e2e.service_helper import Helper, connection_string_to_hostname
from azure.iot.device import ProvisioningDeviceClient
from provisioningserviceclient import ProvisioningServiceClient, IndividualEnrollment
from provisioningserviceclient.protocol.models import AttestationMechanism, ReprovisionPolicy
import pytest
import logging
import os
import uuid
logging.basicConfig(level=logging.DEBUG)
PROVISIONING_HOST = os.getenv("PROVISIONING_DEVICE_ENDPOINT")
ID_SCOPE = os.getenv("PROVISIONING_DEVICE_IDSCOPE")
service_client = ProvisioningServiceClient.create_from_connection_string(
os.getenv("PROVISIONING_SERVICE_CONNECTION_STRING")
)
device_registry_helper = Helper(os.getenv("IOTHUB_CONNECTION_STRING"))
linked_iot_hub = connection_string_to_hostname(os.getenv("IOTHUB_CONNECTION_STRING"))
@pytest.mark.it(
"A device gets provisioned to the linked IoTHub with the device_id equal to the registration_id of the individual enrollment that has been created with a symmetric key authentication"
)
@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"])
def test_device_register_with_no_device_id_for_a_symmetric_key_individual_enrollment(protocol):
try:
individual_enrollment_record = create_individual_enrollment(
"e2e-dps-underthewhompingwillow" + str(uuid.uuid4())
)
registration_id = individual_enrollment_record.registration_id
symmetric_key = individual_enrollment_record.attestation.symmetric_key.primary_key
registration_result = result_from_register(registration_id, symmetric_key, protocol)
assert_device_provisioned(
device_id=registration_id, registration_result=registration_result
)
device_registry_helper.try_delete_device(registration_id)
finally:
service_client.delete_individual_enrollment_by_param(registration_id)
@pytest.mark.it(
"A device gets provisioned to the linked IoTHub with the user supplied device_id different from the registration_id of the individual enrollment that has been created with a symmetric key authentication"
)
@pytest.mark.parametrize("protocol", ["mqtt", "mqttws"])
def test_device_register_with_device_id_for_a_symmetric_key_individual_enrollment(protocol):
device_id = "e2edpstommarvoloriddle"
try:
individual_enrollment_record = create_individual_enrollment(
registration_id="e2e-dps-prioriincantatem" + str(uuid.uuid4()), device_id=device_id
)
registration_id = individual_enrollment_record.registration_id
symmetric_key = individual_enrollment_record.attestation.symmetric_key.primary_key
registration_result = result_from_register(registration_id, symmetric_key, protocol)
assert device_id != registration_id
assert_device_provisioned(device_id=device_id, registration_result=registration_result)
device_registry_helper.try_delete_device(device_id)
finally:
service_client.delete_individual_enrollment_by_param(registration_id)
def create_individual_enrollment(registration_id, device_id=None):
"""
Create an individual enrollment record using the service client
:param registration_id: The registration id of the enrollment
:param device_id: Optional device id
:return: And individual enrollment record
"""
reprovision_policy = ReprovisionPolicy(migrate_device_data=True)
attestation_mechanism = AttestationMechanism(type="symmetricKey")
individual_provisioning_model = IndividualEnrollment.create(
attestation=attestation_mechanism,
registration_id=registration_id,
device_id=device_id,
reprovision_policy=reprovision_policy,
)
return service_client.create_or_update(individual_provisioning_model)
def assert_device_provisioned(device_id, registration_result):
"""
Assert that the device has been provisioned correctly to iothub from the registration result as well as from the device registry
:param device_id: The device id
:param registration_result: The registration result
"""
assert registration_result.status == "assigned"
assert registration_result.registration_state.device_id == device_id
assert registration_result.registration_state.assigned_hub == linked_iot_hub
device = device_registry_helper.get_device(device_id)
assert device is not None
assert device.authentication.type == "sas"
assert device.device_id == device_id
def result_from_register(registration_id, symmetric_key, protocol):
protocol_boolean_mapping = {"mqtt": False, "mqttws": True}
provisioning_device_client = ProvisioningDeviceClient.create_from_symmetric_key(
provisioning_host=PROVISIONING_HOST,
registration_id=registration_id,
id_scope=ID_SCOPE,
symmetric_key=symmetric_key,
websockets=protocol_boolean_mapping[protocol],
)
result = provisioning_device_client.register()
provisioning_device_client.shutdown()
return result

Просмотреть файл

@ -160,6 +160,12 @@ sk_sm_create_exceptions = [
pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"),
] ]
# Does the session exit gracefully or because of error?
graceful_exit_params = [
pytest.param(True, id="graceful exit"),
pytest.param(False, id="exit because of exception"),
]
@pytest.mark.describe("IoTHubSession -- Instantiation") @pytest.mark.describe("IoTHubSession -- Instantiation")
class TestIoTHubSessionInstantiation: class TestIoTHubSessionInstantiation:
@ -1926,6 +1932,29 @@ class TestIoTHubSessionMessages:
assert session._mqtt_client.enable_c2d_message_receive.await_count == 1 assert session._mqtt_client.enable_c2d_message_receive.await_count == 1
assert session._mqtt_client.disable_c2d_message_receive.await_count == 1 assert session._mqtt_client.disable_c2d_message_receive.await_count == 1
@pytest.mark.it(
"Does not attempt to disable C2D message receive upon exit if IoTHubMQTTClient is disconnected"
)
@pytest.mark.parametrize("graceful_exit", graceful_exit_params)
async def test_context_manager_exit_while_disconnected(
self, session, arbitrary_exception, graceful_exit
):
assert session._mqtt_client.enable_c2d_message_receive.await_count == 0
assert session._mqtt_client.disable_c2d_message_receive.await_count == 0
try:
async with session.messages():
assert session._mqtt_client.enable_c2d_message_receive.await_count == 1
assert session._mqtt_client.disable_c2d_message_receive.await_count == 0
session._mqtt_client.connected = False
if not graceful_exit:
raise arbitrary_exception
except type(arbitrary_exception):
pass
assert session._mqtt_client.enable_c2d_message_receive.await_count == 1
assert session._mqtt_client.disable_c2d_message_receive.await_count == 0
@pytest.mark.it( @pytest.mark.it(
"Yields an AsyncGenerator that yields the C2D messages yielded by the IoTHubMQTTClient's incoming C2D message generator" "Yields an AsyncGenerator that yields the C2D messages yielded by the IoTHubMQTTClient's incoming C2D message generator"
) )
@ -2094,6 +2123,29 @@ class TestIoTHubSessionDirectMethodRequests:
assert session._mqtt_client.enable_direct_method_request_receive.await_count == 1 assert session._mqtt_client.enable_direct_method_request_receive.await_count == 1
assert session._mqtt_client.disable_direct_method_request_receive.await_count == 1 assert session._mqtt_client.disable_direct_method_request_receive.await_count == 1
@pytest.mark.it(
"Does not attempt to disable direct method request receive upon exit if IoTHubMQTTClient is disconnected"
)
@pytest.mark.parametrize("graceful_exit", graceful_exit_params)
async def test_context_manager_exit_while_disconnected(
self, session, arbitrary_exception, graceful_exit
):
assert session._mqtt_client.enable_direct_method_request_receive.await_count == 0
assert session._mqtt_client.disable_direct_method_request_receive.await_count == 0
try:
async with session.direct_method_requests():
assert session._mqtt_client.enable_direct_method_request_receive.await_count == 1
assert session._mqtt_client.disable_direct_method_request_receive.await_count == 0
session._mqtt_client.connected = False
if not graceful_exit:
raise arbitrary_exception
except type(arbitrary_exception):
pass
assert session._mqtt_client.enable_direct_method_request_receive.await_count == 1
assert session._mqtt_client.disable_direct_method_request_receive.await_count == 0
@pytest.mark.it( @pytest.mark.it(
"Yields an AsyncGenerator that yields the direct method requests yielded by the IoTHubMQTTClient's incoming direct method request message generator" "Yields an AsyncGenerator that yields the direct method requests yielded by the IoTHubMQTTClient's incoming direct method request message generator"
) )
@ -2268,6 +2320,29 @@ class TestIoTHubSessionDesiredPropertyUpdates:
assert session._mqtt_client.enable_twin_patch_receive.await_count == 1 assert session._mqtt_client.enable_twin_patch_receive.await_count == 1
assert session._mqtt_client.disable_twin_patch_receive.await_count == 1 assert session._mqtt_client.disable_twin_patch_receive.await_count == 1
@pytest.mark.it(
"Does not attempt to disable twin patch receive upon exit if IoTHubMQTTClient is disconnected"
)
@pytest.mark.parametrize("graceful_exit", graceful_exit_params)
async def test_context_manager_exit_while_disconnected(
self, session, arbitrary_exception, graceful_exit
):
assert session._mqtt_client.enable_twin_patch_receive.await_count == 0
assert session._mqtt_client.disable_twin_patch_receive.await_count == 0
try:
async with session.desired_property_updates():
assert session._mqtt_client.enable_twin_patch_receive.await_count == 1
assert session._mqtt_client.disable_twin_patch_receive.await_count == 0
session._mqtt_client.connected = False
if not graceful_exit:
raise arbitrary_exception
except type(arbitrary_exception):
pass
assert session._mqtt_client.enable_twin_patch_receive.await_count == 1
assert session._mqtt_client.disable_twin_patch_receive.await_count == 0
@pytest.mark.it( @pytest.mark.it(
"Yields an AsyncGenerator that yields the desired property patches yielded by the IoTHubMQTTClient's incoming twin patch generator" "Yields an AsyncGenerator that yields the desired property patches yielded by the IoTHubMQTTClient's incoming twin patch generator"
) )