зеркало из
1
0
Форкнуть 0

marshal all calls into a single pipeline thread (#133)

* marshal all calls into a single pipeline thread

* incorporate review feedback

* code review feedback and tests

* flake8

* PR feedback

* Automatic LF/CRLF handling regardless of user's .gitconfig settings
This commit is contained in:
Bert Kleewein 2019-07-22 06:03:27 -07:00 коммит произвёл GitHub
Родитель cb0c5f7e1e
Коммит 1206459730
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
28 изменённых файлов: 802 добавлений и 101 удалений

25
.gitattributes поставляемый Normal file
Просмотреть файл

@ -0,0 +1,25 @@
# Default behavior: if Git thinks a file is text (as opposed to binary), it
# will normalize line endings to LF in the repository, but convert to your
# platform's native line endings on checkout (e.g., CRLF for Windows).
* text=auto
# Explicitly declare text files you want to always be normalized and converted
# to native line endings on checkout. E.g.,
*.md text=auto
*.json text=auto
*.ps1 text=auto
# Declare files that will always have LF line endings on checkout. E.g.,
*.sh text eol=lf
# Denote all files that should not have line endings normalized, should not be
# merged, and should not show in a textual diff.
*.docm binary
*.docx binary
*.ico binary
*.lib binary
*.png binary
*.pptx binary
*.snk binary
*.vsdx binary
*.xps binary

Просмотреть файл

@ -56,6 +56,7 @@ class MQTTTransport(object):
self._mqtt_client = mqtt.Client(
client_id=self._client_id, clean_session=False, protocol=mqtt.MQTTv311
)
self._mqtt_client.enable_logger(logging.getLogger("paho"))
def on_connect(client, userdata, flags, rc):
logger.info("connected with result code: {}".format(rc))

Просмотреть файл

@ -6,12 +6,15 @@
import logging
import sys
from . import pipeline_thread
from azure.iot.device.common import unhandled_exceptions
from six.moves import queue
logger = logging.getLogger(__name__)
@pipeline_thread.runs_on_pipeline_thread
def run_ops_in_serial(stage, *args, **kwargs):
"""
Run the operations passed in *args in a serial manner, such that each operation waits for the
@ -57,9 +60,11 @@ def run_ops_in_serial(stage, *args, **kwargs):
if not callback:
raise TypeError("callback is required")
@pipeline_thread.runs_on_pipeline_thread
def on_last_op_done(last_op):
if finally_op:
@pipeline_thread.runs_on_pipeline_thread
def on_finally_done(finally_op):
logger.info(
"{}({}):run_ops_serial: finally_op done.".format(stage.name, finally_op.name)
@ -85,7 +90,7 @@ def run_ops_in_serial(stage, *args, **kwargs):
),
exc_info=e,
)
stage.pipeline_root.unhandled_error_handler(e)
unhandled_exceptions.exception_caught_in_background_thread(e)
finally_op.callback = on_finally_done
logger.info(
@ -107,8 +112,9 @@ def run_ops_in_serial(stage, *args, **kwargs):
),
exc_info=e,
)
stage.pipeline_root.unhandled_error_handler(e)
unhandled_exceptions.exception_caught_in_background_thread(e)
@pipeline_thread.runs_on_pipeline_thread
def on_op_done(completed_op):
logger.info(
"{}({}):run_ops_serial: completed. {} items left".format(
@ -146,6 +152,7 @@ def run_ops_in_serial(stage, *args, **kwargs):
pass_op_to_next_stage(stage, first_op)
@pipeline_thread.runs_on_pipeline_thread
def delegate_to_different_op(stage, original_op, new_op):
"""
Continue an operation using a new operation. This means that the new operation
@ -182,6 +189,7 @@ def delegate_to_different_op(stage, original_op, new_op):
logger.info("{}({}): continuing with {} op".format(stage.name, original_op.name, new_op.name))
@pipeline_thread.runs_on_pipeline_thread
def new_op_complete(op):
logger.info(
"{}({}): completing with result from {}".format(
@ -195,6 +203,7 @@ def delegate_to_different_op(stage, original_op, new_op):
pass_op_to_next_stage(stage, new_op)
@pipeline_thread.runs_on_pipeline_thread
def pass_op_to_next_stage(stage, op):
"""
Helper function to continue a given operation by passing it to the next stage
@ -220,6 +229,7 @@ def pass_op_to_next_stage(stage, op):
stage.next.run_op(op)
@pipeline_thread.runs_on_pipeline_thread
def complete_op(stage, op):
"""
Helper function to complete an operation by calling its callback function thus
@ -240,9 +250,10 @@ def complete_op(stage, op):
),
exc_info=e,
)
stage.pipeline_root.unhandled_error_handler(e)
unhandled_exceptions.exception_caught_in_background_thread(e)
@pipeline_thread.runs_on_pipeline_thread
def pass_event_to_previous_stage(stage, event):
"""
Helper function to pass an event to the previous stage of the pipeline. This is the default
@ -259,4 +270,4 @@ def pass_event_to_previous_stage(stage, event):
error = NotImplementedError(
"{} unhandled at {} stage with no previous stage".format(event.name, stage.name)
)
stage.pipeline_root.unhandled_error_handler(error)
unhandled_exceptions.exception_caught_in_background_thread(error)

Просмотреть файл

@ -12,6 +12,8 @@ from six.moves import queue
from . import pipeline_events_base
from . import pipeline_ops_base
from . import operation_flow
from . import pipeline_thread
from azure.iot.device.common import unhandled_exceptions
logger = logging.getLogger(__name__)
@ -75,6 +77,7 @@ class PipelineStage(object):
self.previous = None
self.pipeline_root = None
@pipeline_thread.runs_on_pipeline_thread
def run_op(self, op):
"""
Run the given operation. This is the public function that outside callers would call to run an
@ -111,6 +114,7 @@ class PipelineStage(object):
"""
pass
@pipeline_thread.runs_on_pipeline_thread
def handle_pipeline_event(self, event):
"""
Handle a pipeline event that arrives from the stage below this stage. Derived
@ -123,11 +127,9 @@ class PipelineStage(object):
try:
self._handle_pipeline_event(event)
except Exception as e:
logger.error(
msg="Error in %s._handle_pipeline_event call".format(self.name), exc_info=e
)
self.pipeline_root.unhandled_error_handler(e)
unhandled_exceptions.exception_caught_in_background_thread(e)
@pipeline_thread.runs_on_pipeline_thread
def _handle_pipeline_event(self, event):
"""
Handle a pipeline event that arrives from the stage below this stage. This
@ -138,6 +140,7 @@ class PipelineStage(object):
"""
operation_flow.pass_event_to_previous_stage(self, event)
@pipeline_thread.runs_on_pipeline_thread
def on_connected(self):
"""
Called by lower layers when the protocol client connects
@ -145,6 +148,7 @@ class PipelineStage(object):
if self.previous:
self.previous.on_connected()
@pipeline_thread.runs_on_pipeline_thread
def on_disconnected(self):
"""
Called by lower layers when the protocol client disconnects
@ -170,6 +174,11 @@ class PipelineRootStage(PipelineStage):
super(PipelineRootStage, self).__init__()
self.on_pipeline_event = None
def run_op(self, op):
op.callback = pipeline_thread.invoke_on_callback_thread_nowait(op.callback)
pipeline_thread.invoke_on_pipeline_thread(super(PipelineRootStage, self).run_op)(op)
@pipeline_thread.runs_on_pipeline_thread
def _run_op(self, op):
"""
run the operation. At the root, the only thing to do is to pass the operation
@ -196,16 +205,7 @@ class PipelineRootStage(PipelineStage):
new_next_stage.pipeline_root = self
return self
def unhandled_error_handler(self, error):
"""
Handler for errors that happen which cannot be tied to a specific operation.
This is still a tentative implimentation and masy be replaced by
some other mechanism as details on behavior are finalized.
"""
# TODO: decide how to pass this error to the app
# TODO: if there's an error in the app handler, print it and exit
pass
@pipeline_thread.runs_on_pipeline_thread
def _handle_pipeline_event(self, event):
"""
Override of the PipelineEvent handler. Because this is the root of the pipeline,
@ -241,6 +241,7 @@ class EnsureConnectionStage(PipelineStage):
self.queue = queue.Queue()
self.blocked = False
@pipeline_thread.runs_on_pipeline_thread
def _run_op(self, op):
# If this stage is currently blocked (because we're waiting for a connection
# to complete, we queue up all operations until after the connect completes.
@ -290,6 +291,7 @@ class EnsureConnectionStage(PipelineStage):
else:
operation_flow.pass_op_to_next_stage(self, op)
@pipeline_thread.runs_on_pipeline_thread
def _block(self, op):
"""
block this stage while we're waiting for the connection to complete.
@ -297,6 +299,7 @@ class EnsureConnectionStage(PipelineStage):
logger.info("{}({}): enabling block".format(self.name, op.name))
self.blocked = True
@pipeline_thread.runs_on_pipeline_thread
def _unblock(self, op, error):
"""
Unblock this stage after the connection is complete. This also means
@ -329,6 +332,7 @@ class EnsureConnectionStage(PipelineStage):
)
self.run_op(op_to_release)
@pipeline_thread.runs_on_pipeline_thread
def _do_connect(self, op):
"""
Start connecting the protocol client in response to some operation (which may or may not be a Connect operation)
@ -343,6 +347,7 @@ class EnsureConnectionStage(PipelineStage):
self.queue.put_nowait(op)
# function that gets called after we're connected.
@pipeline_thread.runs_on_pipeline_thread
def on_connected(op_connect):
logger.info("{}({}): connection is complete".format(self.name, op.name))
# if we're connecting because some layer above us asked us to connect, we complete that operation
@ -361,10 +366,12 @@ class EnsureConnectionStage(PipelineStage):
self, pipeline_ops_base.ConnectOperation(callback=on_connected)
)
@pipeline_thread.runs_on_pipeline_thread
def on_connected(self):
self.connected = True
PipelineStage.on_connected(self)
@pipeline_thread.runs_on_pipeline_thread
def on_disconnected(self):
self.connected = False
PipelineStage.on_disconnected(self)
@ -381,6 +388,7 @@ class CoordinateRequestAndResponseStage(PipelineStage):
super(CoordinateRequestAndResponseStage, self).__init__()
self.pending_responses = {}
@pipeline_thread.runs_on_pipeline_thread
def _run_op(self, op):
if isinstance(op, pipeline_ops_base.SendIotRequestAndWaitForResponseOperation):
# Convert SendIotRequestAndWaitForResponseOperation operation into a SendIotRequestOperation operation
@ -390,6 +398,7 @@ class CoordinateRequestAndResponseStage(PipelineStage):
request_id = str(uuid.uuid4())
@pipeline_thread.runs_on_pipeline_thread
def on_send_request_done(send_request_op):
logger.info(
"{}({}): Finished sending {} request to {} resource {}".format(
@ -433,6 +442,7 @@ class CoordinateRequestAndResponseStage(PipelineStage):
else:
operation_flow.pass_op_to_next_stage(self, op)
@pipeline_thread.runs_on_pipeline_thread
def _handle_pipeline_event(self, event):
if isinstance(event, pipeline_events_base.IotResponseEvent):
# match IotResponseEvent events to the saved dictionary of SendIotRequestAndWaitForResponseOperation

Просмотреть файл

@ -11,6 +11,7 @@ from . import (
pipeline_ops_mqtt,
pipeline_events_mqtt,
operation_flow,
pipeline_thread,
)
from azure.iot.device.common.mqtt_transport import MQTTTransport
@ -24,6 +25,7 @@ class MQTTClientStage(PipelineStage):
is not in the MQTT group of operations, but can only be run at the protocol level.
"""
@pipeline_thread.runs_on_pipeline_thread
def _run_op(self, op):
if isinstance(op, pipeline_ops_mqtt.SetMQTTConnectionArgsOperation):
# pipeline_ops_mqtt.SetMQTTConnectionArgsOperation is where we create our MQTTTransport object and set
@ -62,6 +64,7 @@ class MQTTClientStage(PipelineStage):
elif isinstance(op, pipeline_ops_base.ConnectOperation):
logger.info("{}({}): conneting".format(self.name, op.name))
@pipeline_thread.invoke_on_pipeline_thread_nowait
def on_connected():
logger.info("{}({}): on_connected. completing op.".format(self.name, op.name))
self.transport.on_mqtt_connected = self.on_connected
@ -106,6 +109,7 @@ class MQTTClientStage(PipelineStage):
elif isinstance(op, pipeline_ops_base.ReconnectOperation):
logger.info("{}({}): reconnecting".format(self.name, op.name))
@pipeline_thread.invoke_on_pipeline_thread_nowait
def on_connected():
logger.info("{}({}): on_connected. completing op.".format(self.name, op.name))
self.transport.on_mqtt_connected = self.on_connected
@ -123,6 +127,7 @@ class MQTTClientStage(PipelineStage):
elif isinstance(op, pipeline_ops_base.DisconnectOperation):
logger.info("{}({}): disconnecting".format(self.name, op.name))
@pipeline_thread.invoke_on_pipeline_thread_nowait
def on_disconnected():
logger.info("{}({}): on_disconnected. completing op.".format(self.name, op.name))
self.transport.on_mqtt_disconnected = self.on_disconnected
@ -140,6 +145,7 @@ class MQTTClientStage(PipelineStage):
elif isinstance(op, pipeline_ops_mqtt.MQTTPublishOperation):
logger.info("{}({}): publishing on {}".format(self.name, op.name, op.topic))
@pipeline_thread.invoke_on_pipeline_thread_nowait
def on_published():
logger.info("{}({}): PUBACK received. completing op.".format(self.name, op.name))
operation_flow.complete_op(self, op)
@ -149,6 +155,7 @@ class MQTTClientStage(PipelineStage):
elif isinstance(op, pipeline_ops_mqtt.MQTTSubscribeOperation):
logger.info("{}({}): subscribing to {}".format(self.name, op.name, op.topic))
@pipeline_thread.invoke_on_pipeline_thread_nowait
def on_subscribed():
logger.info("{}({}): SUBACK received. completing op.".format(self.name, op.name))
operation_flow.complete_op(self, op)
@ -158,6 +165,7 @@ class MQTTClientStage(PipelineStage):
elif isinstance(op, pipeline_ops_mqtt.MQTTUnsubscribeOperation):
logger.info("{}({}): unsubscribing from {}".format(self.name, op.name, op.topic))
@pipeline_thread.invoke_on_pipeline_thread_nowait
def on_unsubscribed():
logger.info("{}({}): UNSUBACK received. completing op.".format(self.name, op.name))
operation_flow.complete_op(self, op)
@ -167,6 +175,7 @@ class MQTTClientStage(PipelineStage):
else:
operation_flow.pass_op_to_next_stage(self, op)
@pipeline_thread.invoke_on_pipeline_thread_nowait
def _on_message_received(self, topic, payload):
"""
Handler that gets called by the protocol library when an incoming message arrives.
@ -176,3 +185,11 @@ class MQTTClientStage(PipelineStage):
stage=self,
event=pipeline_events_mqtt.IncomingMQTTMessageEvent(topic=topic, payload=payload),
)
@pipeline_thread.invoke_on_pipeline_thread_nowait
def on_connected(self):
super(MQTTClientStage, self).on_connected()
@pipeline_thread.invoke_on_pipeline_thread_nowait
def on_disconnected(self):
super(MQTTClientStage, self).on_disconnected()

Просмотреть файл

@ -0,0 +1,196 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import functools
import logging
import threading
import traceback
from multiprocessing.pool import ThreadPool
from concurrent.futures import ThreadPoolExecutor
from azure.iot.device.common import unhandled_exceptions
logger = logging.getLogger(__name__)
"""
This module contains decorators that are used to marshal code into pipeline and
callback threads and to assert that code is being called in the correct thread.
The intention of these decorators is to ensure the following:
1. All pipeline functions execute in a single thread, known as the "pipeline
thread". The `invoke_on_pipeline_thread` and `invoke_on_pipeline_thread_nowait`
decorators cause the decorated function to run on the pipeline thread.
2. If the pipeline thread is busy running a different function, the invoke
decorators will wait until that function is complete before invoking another
function on that thread.
3. There is a different thread which is used for callbacks into user code, known
as the the "callback thread". This is not meant for callbacks into pipeline
code. Those callbacks should still execute on the pipeline thread. The
`invoke_on_callback_thread_nowait` decorator is used to ensure that callbacks
execute on the callback thread.
4. Decorators which cause thread switches are used only when necessary. The
pipeline thread is only entered in places where we know that external code is
calling into the pipeline (such as a client API call or a callback from a
third-party library). Likewise, the callback thread is only entered in places
where we know that the pipeline is calling back into client code.
5. Exceptions raised from the pipeline thread are still able to be caught by
the function which entered the pipeline thread.
5. Calls into the pipeline thread can either block or not block. Blocking is used
for cases where the caller needs a return value from the pipeline or is
expecting to handle any errors raised from the pipeline thread. Blocking is
not used when the code calling into the pipeline is not waiting for a response
and is not expecting to handle any exceptions, such as protocol library
handlers which call into the pipeline to deliver protocol messages.
6. Calls into the callback thread could theoretically block, but we currently
only have decorators which enter the callback thread without blocking. This
is done to ensure that client code does not execute on the pipeline thread and
also to ensure that the pipline thread is not blocked while waiting for client
code to execute.
These decorators use concurrent.futures.Future and the ThreadPoolExecutor because:
1. The thread pooling with a pool size of 1 gives us a single thread to run all
pipeline operations and a different (single) thread to run all callbacks. If
the code attempts to run a second pipeline operation (or callback) while a
different one is running, the ThreadPoolExecutor will queue the code until the
first call is completed.
2. The concurent.futures.Future object properly handles both Exception and
BaseException errors, re-raising them when the Future.result method is called.
threading.Thread.get() was not an option because it doesn't re-raise
BaseException errors when Thread.get is called.
3. concurrent.futures is available as a backport to 2.7.
"""
_executors = {}
def _get_named_executor(thread_name):
"""
Get a ThreadPoolExecutor object with the given name. If no such executor exists,
this function will create on with a single worker and assign it to the provided
name.
"""
global _executors
if thread_name not in _executors:
logger.info("Creating {} executor".format(thread_name))
_executors[thread_name] = ThreadPoolExecutor(max_workers=1)
return _executors[thread_name]
def _invoke_on_executor_thread(func, thread_name, block=True):
"""
Return wrapper to run the function on a given thread. If block==False,
the call returns immediately without waiting for the decorated function to complete.
If block==True, the call waits for the decorated function to complete before returning.
"""
# Mocks on py27 don't have a __name__ attribute. Use str() if you can't use __name__
try:
function_name = func.__name__
function_has_name = True
except AttributeError:
function_name = str(func)
function_has_name = False
def wrapper(*args, **kwargs):
if threading.current_thread().name is not thread_name:
logger.info("Starting {} in {} thread".format(function_name, thread_name))
def thread_proc():
threading.current_thread().name = thread_name
try:
return func(*args, **kwargs)
except Exception as e:
if not block:
unhandled_exceptions.exception_caught_in_background_thread(e)
except BaseException:
if not block:
logger.error("Unhandled exception in background thread")
logger.error(
"This may cause the background thread to abort and may result in system instability."
)
traceback.print_exc()
raise
# TODO: add a timeout here and throw exception on failure
future = _get_named_executor(thread_name).submit(thread_proc)
if block:
return future.result()
else:
return future
else:
logger.debug("Already in {} thread for {}".format(thread_name, function_name))
return func(*args, **kwargs)
# Silly hack: On 2.7, we can't use @functools.wraps on callables don't have a __name__ attribute
# attribute(like MagicMock object), so we only do it when we have a name. functools.update_wrapper
# below is the same as using the @functools.wraps(func) decorator on the wrapper function above.
if function_has_name:
return functools.update_wrapper(wrapped=func, wrapper=wrapper)
else:
wrapper.__wrapped__ = func # needed by tests
return wrapper
def invoke_on_pipeline_thread(func):
"""
Run the decorated function on the pipeline thread.
"""
return _invoke_on_executor_thread(func=func, thread_name="pipeline")
def invoke_on_pipeline_thread_nowait(func):
"""
Run the decorated function on the pipeline thread, but don't wait for it to complete
"""
return _invoke_on_executor_thread(func=func, thread_name="pipeline", block=False)
def invoke_on_callback_thread_nowait(func):
"""
Run the decorated function on the callback thread, but don't wait for it to complete
"""
return _invoke_on_executor_thread(func=func, thread_name="callback", block=False)
def _assert_executor_thread(func, thread_name):
"""
Decorator which asserts that the given function only gets called inside the given
thread.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
assert (
threading.current_thread().name == thread_name
), """
Function {function_name} is not running inside {thread_name} thread.
It should be. You should use invoke_on_{thread_name}_thread(_nowait) to enter the
{thread_name} thread before calling this function. If you're hitting this from
inside a test function, you may need to add the fake_pipeline_thread fixture to
your test. (grep for apply_fake_pipeline_thread) """.format(
function_name=func.__name__, thread_name=thread_name
)
return func(*args, **kwargs)
return wrapper
def runs_on_pipeline_thread(func):
"""
Decorator which marks a function as only running inside the pipeline thread.
"""
return _assert_executor_thread(func=func, thread_name="pipeline")

Просмотреть файл

@ -0,0 +1,27 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
logger = logging.getLogger(__name__)
def exception_caught_in_background_thread(e):
"""
Function which handled exceptions that are caught in background thread. This is
typically called from the callback thread inside the pipeline. These exceptions
need special handling because callback functions are typically called inside a
non-application thread in response to non-user-initiated actions, so there's
nobody else to catch them.
This function gets called from inside an arbitrary thread context, so code that
runs from this function should be limited to the bare minumum.
:param Error e: Exception object raised from inside a background thread
"""
# @FUTURE: We should add a mechanism which allows applications to receive these
# exceptions so they can respond accordingly
logger.error(msg="Exception caught in background thread. Unable to handle.", exc_info=e)

Просмотреть файл

@ -6,7 +6,12 @@
import json
import logging
from azure.iot.device.common.pipeline import pipeline_ops_base, PipelineStage, operation_flow
from azure.iot.device.common.pipeline import (
pipeline_ops_base,
PipelineStage,
operation_flow,
pipeline_thread,
)
from . import pipeline_ops_iothub
from . import constant
@ -27,6 +32,7 @@ class UseAuthProviderStage(PipelineStage):
All other operations are passed down.
"""
@pipeline_thread.runs_on_pipeline_thread
def _run_op(self, op):
def pipeline_ops_done(completed_op):
op.error = completed_op.error
@ -78,6 +84,7 @@ class HandleTwinOperationsStage(PipelineStage):
protocol-specific receive event into an IotResponseEvent event.
"""
@pipeline_thread.runs_on_pipeline_thread
def _run_op(self, op):
def map_twin_error(original_op, twin_op):
if twin_op.error:

Просмотреть файл

@ -13,6 +13,7 @@ from azure.iot.device.common.pipeline import (
pipeline_events_mqtt,
PipelineStage,
operation_flow,
pipeline_thread,
)
from azure.iot.device.iothub.models import Message, MethodRequest
from . import constant, pipeline_ops_iothub, pipeline_events_iothub, mqtt_topic_iothub
@ -30,6 +31,7 @@ class IoTHubMQTTConverterStage(PipelineStage):
super(IoTHubMQTTConverterStage, self).__init__()
self.feature_to_topic = {}
@pipeline_thread.runs_on_pipeline_thread
def _run_op(self, op):
if isinstance(op, pipeline_ops_iothub.SetAuthProviderArgsOperation):
@ -127,6 +129,7 @@ class IoTHubMQTTConverterStage(PipelineStage):
# All other operations get passed down
operation_flow.pass_op_to_next_stage(self, op)
@pipeline_thread.runs_on_pipeline_thread
def _set_topic_names(self, device_id, module_id):
"""
Build topic names based on the device_id and module_id passed.
@ -144,6 +147,7 @@ class IoTHubMQTTConverterStage(PipelineStage):
constant.TWIN_PATCHES: (mqtt_topic_iothub.get_twin_patch_topic_for_subscribe()),
}
@pipeline_thread.runs_on_pipeline_thread
def _handle_pipeline_event(self, event):
"""
Pipeline Event handler function to convert incoming MQTT messages into the appropriate IoTHub

Просмотреть файл

@ -4,7 +4,7 @@
# license information.
# --------------------------------------------------------------------------
from azure.iot.device.common.pipeline import pipeline_ops_base, operation_flow
from azure.iot.device.common.pipeline import pipeline_ops_base, operation_flow, pipeline_thread
from azure.iot.device.common.pipeline.pipeline_stages_base import PipelineStage
from . import pipeline_ops_provisioning
@ -28,6 +28,7 @@ class UseSecurityClientStage(PipelineStage):
All other operations are passed down.
"""
@pipeline_thread.runs_on_pipeline_thread
def _run_op(self, op):
def pipeline_ops_done(completed_op):
op.error = completed_op.error

Просмотреть файл

@ -11,6 +11,7 @@ from azure.iot.device.common.pipeline import (
pipeline_ops_mqtt,
pipeline_events_mqtt,
operation_flow,
pipeline_thread,
)
from azure.iot.device.common.pipeline.pipeline_stages_base import PipelineStage
from azure.iot.device.provisioning.pipeline import constant, mqtt_topic
@ -32,6 +33,7 @@ class ProvisioningMQTTConverterStage(PipelineStage):
super(ProvisioningMQTTConverterStage, self).__init__()
self.action_to_topic = {}
@pipeline_thread.runs_on_pipeline_thread
def _run_op(self, op):
if isinstance(
@ -103,6 +105,7 @@ class ProvisioningMQTTConverterStage(PipelineStage):
# All other operations get passed down
operation_flow.pass_op_to_next_stage(self, op)
@pipeline_thread.runs_on_pipeline_thread
def _handle_pipeline_event(self, event):
"""
Pipeline Event handler function to convert incoming MQTT messages into the appropriate DPS

Просмотреть файл

@ -63,6 +63,7 @@ setup(
"requests>=2.20.0,<3.0.0",
"requests-unixsocket>=0.1.5,<1.0.0",
"janus>=0.4.0,<1.0.0;python_version>='3.5'",
"futures;python_version == '2.7'",
],
extras_require={":python_version<'3.0'": ["azure-iot-nspkg>=1.0.1"]},
python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3*, <4",

Просмотреть файл

@ -14,4 +14,7 @@ from tests.common.pipeline.fixtures import (
op3,
finally_op,
new_op,
fake_pipeline_thread,
fake_non_pipeline_thread,
unhandled_error_handler,
)

Просмотреть файл

@ -4,8 +4,14 @@
# license information.
# --------------------------------------------------------------------------
import pytest
import threading
from tests.common.pipeline import helpers
from azure.iot.device.common.pipeline import pipeline_events_base, pipeline_ops_base
from azure.iot.device.common import unhandled_exceptions
from azure.iot.device.common.pipeline import (
pipeline_events_base,
pipeline_ops_base,
pipeline_thread,
)
@pytest.fixture
@ -39,35 +45,67 @@ class FakeOperation(pipeline_ops_base.PipelineOperation):
@pytest.fixture
def op():
op = FakeOperation()
def op(callback):
op = FakeOperation(callback=callback)
op.name = "op"
return op
@pytest.fixture
def op2():
op = FakeOperation()
def op2(callback):
op = FakeOperation(callback=callback)
op.name = "op2"
return op
@pytest.fixture
def op3():
op = FakeOperation()
def op3(callback):
op = FakeOperation(callback=callback)
op.name = "op3"
return op
@pytest.fixture
def finally_op():
op = FakeOperation()
def finally_op(callback):
op = FakeOperation(callback=callback)
op.name = "finally_op"
return op
@pytest.fixture
def new_op():
op = FakeOperation()
def new_op(callback):
op = FakeOperation(callback=callback)
op.name = "new_op"
return op
@pytest.fixture
def fake_pipeline_thread():
"""
This fixture mocks out the thread name so that the pipeline decorators
use to assert that you are in a pipeline thread.
"""
this_thread = threading.current_thread()
old_name = this_thread.name
this_thread.name = "pipeline"
yield
this_thread.name = old_name
@pytest.fixture
def fake_non_pipeline_thread():
"""
This fixture sets thread name to something other than "pipeline" to force asserts
"""
this_thread = threading.current_thread()
old_name = this_thread.name
this_thread.name = "not pipeline"
yield
this_thread.name = old_name
@pytest.fixture
def unhandled_error_handler(mocker):
return mocker.patch.object(unhandled_exceptions, "exception_caught_in_background_thread")

Просмотреть файл

@ -6,6 +6,7 @@
import inspect
import pytest
import functools
from threading import Event
from azure.iot.device.common.pipeline import (
pipeline_events_base,
pipeline_ops_base,
@ -95,6 +96,13 @@ def make_mock_stage(mocker, stage_to_make):
def assert_callback_succeeded(op, callback=None):
if not callback:
callback = op.callback
try:
# if the callback has a __wrapped__ attribute, that means that the
# pipeline added a wrapper around the callback, so we want to look
# at the original function instead of the wrapped function.
callback = callback.__wrapped__
except AttributeError:
pass
assert callback.call_count == 1
callback_arg = callback.call_args[0][0]
assert callback_arg == op
@ -104,6 +112,13 @@ def assert_callback_succeeded(op, callback=None):
def assert_callback_failed(op, callback=None, error=None):
if not callback:
callback = op.callback
try:
# if the callback has a __wrapped__ attribute, that means that the
# pipeline added a wrapper around the callback, so we want to look
# at the original function instead of the wrapped function.
callback = callback.__wrapped__
except AttributeError:
pass
assert callback.call_count == 1
callback_arg = callback.call_args[0][0]
assert callback_arg == op
@ -124,11 +139,45 @@ class UnhandledException(BaseException):
def get_arg_count(fn):
"""
return the number of arguments (args) passed into a
particular function. Returned value not include kwargs.
particular function. Returned value does not include kwargs.
"""
return len(getargspec(fn).args)
try:
# if __wrapped__ is set, we're looking at a decorated function
# Functools.wraps doesn't copy arg metadata, so we need to
# get argument count from the wrapped function instead.
return len(getargspec(fn.__wrapped__).args)
except AttributeError:
return len(getargspec(fn).args)
def make_mock_op_or_event(cls):
args = [None for i in (range(get_arg_count(cls.__init__) - 1))]
return cls(*args)
def add_mock_method_waiter(obj, method_name):
"""
For mock methods, add "wait_for_xxx_to_be_called" and "wait_for_xxx_to_not_be_called"
helper functions on the object. This is very handy for methods that get called by
another thread, when you want your test functions to wait until the other thread is
able to call the method without using a sleep call.
"""
method_called = Event()
def signal_method_called(*args, **kwargs):
method_called.set()
def wait_for_method_to_be_called():
method_called.wait(0.1)
assert method_called.isSet()
method_called.clear()
def wait_for_method_to_not_be_called():
method_called.wait(0.1)
assert not method_called.isSet()
getattr(obj, method_name).side_effect = signal_method_called
setattr(obj, "wait_for_{}_to_be_called".format(method_name), wait_for_method_to_be_called)
setattr(
obj, "wait_for_{}_to_not_be_called".format(method_name), wait_for_method_to_not_be_called
)

Просмотреть файл

@ -5,20 +5,35 @@
# --------------------------------------------------------------------------
import logging
import pytest
import inspect
import threading
import concurrent.futures
from tests.common.pipeline.helpers import (
all_except,
make_mock_stage,
make_mock_op_or_event,
assert_callback_failed,
UnhandledException,
get_arg_count,
add_mock_method_waiter,
)
from azure.iot.device.common.pipeline.pipeline_stages_base import PipelineStage
from tests.common.pipeline.pipeline_data_object_test import add_instantiation_test
from azure.iot.device.common.pipeline import pipeline_thread
logging.basicConfig(level=logging.INFO)
def add_base_pipeline_stage_tests(cls, module, all_ops, handled_ops, all_events, handled_events):
def add_base_pipeline_stage_tests(
cls,
module,
all_ops,
handled_ops,
all_events,
handled_events,
methods_that_enter_pipeline_thread=[],
methods_that_can_run_in_any_thread=[],
):
"""
Add all of the "basic" tests for validating a pipeline stage. This includes tests for
instantiation and tests for properly handling "unhandled" operations and events".
@ -33,6 +48,12 @@ def add_base_pipeline_stage_tests(cls, module, all_ops, handled_ops, all_events,
add_unknown_events_tests(
cls=cls, module=module, all_events=all_events, handled_events=handled_events
)
add_pipeline_thread_tests(
cls=cls,
module=module,
methods_that_enter_pipeline_thread=methods_that_enter_pipeline_thread,
methods_that_can_run_in_any_thread=methods_that_can_run_in_any_thread,
)
def add_unknown_ops_tests(cls, module, all_ops, handled_ops):
@ -44,13 +65,13 @@ def add_unknown_ops_tests(cls, module, all_ops, handled_ops):
unknown_ops = all_except(all_items=all_ops, items_to_exclude=handled_ops)
@pytest.mark.describe("{} - .run_op() -- unknown and unhandled operations".format(cls.__name__))
@pytest.mark.parametrize("op_cls", unknown_ops)
class LocalTestObject(object):
@pytest.fixture
def op(self, op_cls, callback):
op = make_mock_op_or_event(op_cls)
op.callback = callback
op.action = "pend"
add_mock_method_waiter(op, "callback")
return op
@pytest.fixture
@ -58,26 +79,32 @@ def add_unknown_ops_tests(cls, module, all_ops, handled_ops):
return make_mock_stage(mocker=mocker, stage_to_make=cls)
@pytest.mark.it("Passes unknown operation to next stage")
@pytest.mark.parametrize("op_cls", unknown_ops)
def test_passes_op_to_next_stage(self, op_cls, op, stage):
stage.run_op(op)
assert stage.next.run_op.call_count == 1
assert stage.next.run_op.call_args[0][0] == op
@pytest.mark.it("Fails unknown operation if there is no next stage")
@pytest.mark.parametrize("op_cls", unknown_ops)
def test_passes_op_with_no_next_stage(self, op_cls, op, stage):
stage.next = None
stage.run_op(op)
op.wait_for_callback_to_be_called()
assert_callback_failed(op=op)
@pytest.mark.it("Catches Exceptions raised when passing unknown operation to next stage")
@pytest.mark.parametrize("op_cls", unknown_ops)
def test_passes_op_to_next_stage_which_throws_exception(self, op_cls, op, stage):
op.action = "exception"
stage.run_op(op)
op.wait_for_callback_to_be_called()
assert_callback_failed(op=op)
@pytest.mark.it(
"Allows BaseExceptions raised when passing unknown operation to next start to propogate"
)
@pytest.mark.parametrize("op_cls", unknown_ops)
def test_passes_op_to_next_stage_which_throws_base_exception(self, op_cls, op, stage):
op.action = "base_exception"
with pytest.raises(UnhandledException):
@ -95,6 +122,9 @@ def add_unknown_events_tests(cls, module, all_events, handled_events):
unknown_events = all_except(all_items=all_events, items_to_exclude=handled_events)
if not unknown_events:
return
@pytest.mark.describe(
"{} - .handle_pipeline_event() -- unknown and unhandled events".format(cls.__name__)
)
@ -122,16 +152,6 @@ def add_unknown_events_tests(cls, module, all_events, handled_events):
stage.previous = previous
return previous
@pytest.fixture
def unhandled_error_handler(self, stage, mocker):
class MockPipelineRootStage(object):
def __init__(self):
self.unhandled_error_handler = mocker.MagicMock()
root = MockPipelineRootStage()
stage.pipeline_root = root
return root.unhandled_error_handler
@pytest.mark.it("Passes unknown event to previous stage")
def test_passes_event_to_previous_stage(self, event_cls, stage, event, previous):
stage.handle_pipeline_event(event)
@ -167,6 +187,81 @@ def add_unknown_events_tests(cls, module, all_events, handled_events):
stage.handle_pipeline_event(event)
assert unhandled_error_handler.call_count == 0
pass
setattr(module, "Test{}UnknownEvents".format(cls.__name__), LocalTestObject)
class ThreadLaunchedError(Exception):
pass
def add_pipeline_thread_tests(
cls, module, methods_that_enter_pipeline_thread, methods_that_can_run_in_any_thread
):
def does_method_assert_pipeline_thread(method_name):
if method_name.startswith("__"):
return False
elif method_name in methods_that_enter_pipeline_thread:
return False
elif method_name in methods_that_can_run_in_any_thread:
return False
else:
return True
methods_that_assert_pipeline_thread = [
x[0]
for x in inspect.getmembers(cls, inspect.isfunction)
if does_method_assert_pipeline_thread(x[0])
]
@pytest.mark.describe("{} - Pipeline threading".format(cls.__name__))
class LocalTestObject(object):
@pytest.fixture
def stage(self):
return cls()
@pytest.mark.parametrize("method_name", methods_that_assert_pipeline_thread)
@pytest.mark.it("Enforces use of the pipeline thread when calling method")
def test_asserts_in_pipeline(self, stage, method_name, fake_non_pipeline_thread):
func = getattr(stage, method_name)
args = [None for i in (range(get_arg_count(func) - 1))]
with pytest.raises(AssertionError):
func(*args)
if methods_that_enter_pipeline_thread:
@pytest.mark.parametrize("method_name", methods_that_enter_pipeline_thread)
@pytest.mark.it("Automatically enters the pipeline thread when calling method")
def test_enters_pipeline(self, mocker, stage, method_name, fake_non_pipeline_thread):
func = getattr(stage, method_name)
args = [None for i in (range(get_arg_count(func) - 1))]
#
# We take a bit of a roundabout way to verify that the functuion enters the
# pipeline executor:
#
# 1. we verify that the method got the pipeline executor
# 2. we verify that the method invoked _something_ on the pipeline executor
#
# It's not perfect, but it's good enough.
#
# We do this because:
# 1. We don't have the exact right args to run the method and we don't want
# to add the complexity to get the right args in this test.
# 2. We can't replace the wrapped method with a mock, AFAIK.
#
pipeline_executor = pipeline_thread._get_named_executor("pipeline")
mocker.patch.object(pipeline_executor, "submit")
pipeline_executor.submit.side_effect = ThreadLaunchedError
mocker.spy(pipeline_thread, "_get_named_executor")
# If the method calls submit on some executor, it will raise a ThreadLaunchedError
with pytest.raises(ThreadLaunchedError):
func(*args)
# now verify that the code got the pipeline executor and verify that it used that
# executor to launch something.
assert pipeline_thread._get_named_executor.call_count == 1
assert pipeline_thread._get_named_executor.call_args[0][0] == "pipeline"
assert pipeline_executor.submit.call_count == 1
setattr(module, "Test{}PipelineThreading".format(cls.__name__), LocalTestObject)

Просмотреть файл

@ -6,6 +6,7 @@
import logging
import pytest
from azure.iot.device.common.pipeline import (
pipeline_thread,
pipeline_stages_base,
pipeline_ops_base,
pipeline_events_base,
@ -26,6 +27,16 @@ from tests.common.pipeline.helpers import (
logging.basicConfig(level=logging.INFO)
# This fixture makes it look like all test in this file tests are running
# inside the pipeline thread. Because this is an autouse fixture, we
# manually add it to the individual test.py files that need it. If,
# instead, we had added it to some conftest.py, it would be applied to
# every tests in every file and we don't want that.
@pytest.fixture(autouse=True)
def apply_fake_pipeline_thread(fake_pipeline_thread):
pass
class MockPipelineStage(pipeline_stages_base.PipelineStage):
def _run_op(self, op):
pass_op_to_next_stage(self, op)
@ -79,13 +90,15 @@ class TestRunOpsSerialOneOpButNoFinallyOp(object):
@pytest.mark.it(
"Handles Exceptions raised in the callback and passes them to the unhandled error handler"
)
def test_callback_throws_exception(self, stage, mocker, fake_exception, op):
def test_callback_throws_exception(
self, stage, mocker, fake_exception, op, unhandled_error_handler
):
callback = mocker.Mock(side_effect=fake_exception)
run_ops_in_serial(stage, op, callback=callback)
assert callback.call_count == 1
assert callback.call_args == mocker.call(op)
assert stage.unhandled_error_handler.call_count == 1
assert stage.unhandled_error_handler.call_args == mocker.call(fake_exception)
assert unhandled_error_handler.call_count == 1
assert unhandled_error_handler.call_args == mocker.call(fake_exception)
@pytest.mark.it("Allows any BaseExceptions raised in the callback to propagate")
def test_callback_throws_base_exception(self, stage, mocker, fake_base_exception, op):
@ -153,13 +166,15 @@ class TestRunOpsSerialOneOpAndFinallyOp(object):
@pytest.mark.it(
"Handles Exceptions raised in the callback and passes them to the unhandled error handler"
)
def test_callback_raises_exception(self, stage, op, finally_op, fake_exception, mocker):
def test_callback_raises_exception(
self, stage, op, finally_op, fake_exception, mocker, unhandled_error_handler
):
callback = mocker.Mock(side_effect=fake_exception)
run_ops_in_serial(stage, op, finally_op=finally_op, callback=callback)
assert callback.call_count == 1
assert callback.call_args == mocker.call(finally_op)
assert stage.unhandled_error_handler.call_count == 1
assert stage.unhandled_error_handler.call_args == mocker.call(fake_exception)
assert unhandled_error_handler.call_count == 1
assert unhandled_error_handler.call_args == mocker.call(fake_exception)
@pytest.mark.it("Allows any BaseExceptions raised in the callback to propagate")
def test_callback_raises_base_exception(
@ -197,13 +212,15 @@ class TestRunOpsSerialThreeOpsButNoFinallyOp(object):
@pytest.mark.it(
"Handles Exceptions raised in the callback and passes them to the unhandled error handler"
)
def test_callback_raises_exception(self, stage, op, op2, op3, fake_exception, mocker):
def test_callback_raises_exception(
self, stage, op, op2, op3, fake_exception, mocker, unhandled_error_handler
):
callback = mocker.Mock(side_effect=fake_exception)
run_ops_in_serial(stage, op, op2, op3, callback=callback)
assert callback.call_count == 1
assert callback.call_args == mocker.call(op3)
assert stage.unhandled_error_handler.call_count == 1
assert stage.unhandled_error_handler.call_args == mocker.call(fake_exception)
assert unhandled_error_handler.call_count == 1
assert unhandled_error_handler.call_args == mocker.call(fake_exception)
@pytest.mark.it("Allows any BaseExceptions raised in the callback to propagate")
def test_callback_raises_base_exception(self, stage, op, op2, op3, fake_base_exception, mocker):
@ -305,12 +322,12 @@ class TestRunOpsSerialThreeOpsAndFinallyOp(object):
"Handles Exceptions raised in the callback and passes them to the unhandled error handler"
)
def test_callback_raises_exception(
self, stage, op, op2, op3, finally_op, fake_exception, mocker
self, stage, op, op2, op3, finally_op, fake_exception, mocker, unhandled_error_handler
):
callback = mocker.Mock(side_effect=fake_exception)
run_ops_in_serial(stage, op, op2, op3, callback=callback, finally_op=finally_op)
assert stage.unhandled_error_handler.call_count == 1
assert stage.unhandled_error_handler.call_args == mocker.call(fake_exception)
assert unhandled_error_handler.call_count == 1
assert unhandled_error_handler.call_args == mocker.call(fake_exception)
@pytest.mark.it("Allows any BaseExceptions raised in the callback to propagate")
def test_callback_raises_base_exception(
@ -505,13 +522,15 @@ class TestCompleteOp(object):
@pytest.mark.it(
"Handles Exceptions raised in operation callback and passes them to the unhandled error handler"
)
def test_op_callback_raises_exception(self, stage, op, fake_exception, mocker):
def test_op_callback_raises_exception(
self, stage, op, fake_exception, mocker, unhandled_error_handler
):
op.callback = mocker.Mock(side_effect=fake_exception)
complete_op(stage, op)
assert op.callback.call_count == 1
assert op.callback.call_args == mocker.call(op)
assert stage.unhandled_error_handler.call_count == 1
assert stage.unhandled_error_handler.call_args == mocker.call(fake_exception)
assert unhandled_error_handler.call_count == 1
assert unhandled_error_handler.call_args == mocker.call(fake_exception)
@pytest.mark.it("Allows any BaseExceptions raised in operation callback to propagate")
def test_op_callback_raises_base_exception(self, stage, op, fake_base_exception, mocker):

Просмотреть файл

@ -6,6 +6,7 @@
import logging
import pytest
import sys
import threading
from azure.iot.device.common.pipeline import (
pipeline_stages_base,
pipeline_ops_base,
@ -26,6 +27,17 @@ logging.basicConfig(level=logging.INFO)
this_module = sys.modules[__name__]
# This fixture makes it look like all test in this file tests are running
# inside the pipeline thread. Because this is an autouse fixture, we
# manually add it to the individual test.py files that need it. If,
# instead, we had added it to some conftest.py, it would be applied to
# every tests in every file and we don't want that.
@pytest.fixture(autouse=True)
def apply_fake_pipeline_thread(fake_pipeline_thread):
pass
pipeline_stage_test.add_base_pipeline_stage_tests(
cls=pipeline_stages_base.EnsureConnectionStage,
module=this_module,
@ -42,6 +54,68 @@ pipeline_stage_test.add_base_pipeline_stage_tests(
handled_events=[],
)
# Workaround for flake8. A class with this name is actually created inside
# add_base_pipeline_stage_test, but flake8 doesn't know that
class TestPipelineRootStagePipelineThreading:
pass
pipeline_stage_test.add_base_pipeline_stage_tests(
cls=pipeline_stages_base.PipelineRootStage,
module=this_module,
all_ops=all_common_ops,
handled_ops=[],
all_events=all_common_events,
handled_events=all_common_events,
methods_that_can_run_in_any_thread=["append_stage", "run_op"],
)
@pytest.mark.it("Calls operation callback in callback thread")
def _test_pipeline_root_runs_callback_in_callback_thread(self, stage, mocker):
# the stage fixture comes from the TestPipelineRootStagePipelineThreading object that
# this test method gets added to, so it's a PipelineRootStage object
stage.pipeline_root = stage
callback_called = threading.Event()
def callback(op):
assert threading.current_thread().name == "callback"
callback_called.set()
op = pipeline_ops_base.ConnectOperation(callback=callback)
stage.run_op(op)
callback_called.wait()
@pytest.mark.it("Runs operation in pipeline thread")
def _test_pipeline_root_runs_operation_in_pipeline_thread(
self, mocker, stage, op, fake_non_pipeline_thread
):
# the stage fixture comes from the TestPipelineRootStagePipelineThreading object that
# this test method gets added to, so it's a PipelineRootStage object
assert threading.current_thread().name != "pipeline"
def mock_run_op(self, op):
print("mock_run_op called")
assert threading.current_thread().name == "pipeline"
op.callback(op)
mock_run_op = mocker.MagicMock(mock_run_op)
stage._run_op = mock_run_op
stage.run_op(op)
assert mock_run_op.call_count == 1
TestPipelineRootStagePipelineThreading.test_runs_callback_in_callback_thread = (
_test_pipeline_root_runs_callback_in_callback_thread
)
TestPipelineRootStagePipelineThreading.test_runs_operation_in_pipeline_thread = (
_test_pipeline_root_runs_operation_in_pipeline_thread
)
pipeline_stage_test.add_base_pipeline_stage_tests(
cls=pipeline_stages_base.CoordinateRequestAndResponseStage,
module=this_module,
@ -171,27 +245,29 @@ class TestCoordinateRequestAndResponseSendIotRequestHandleEvent(object):
@pytest.mark.it(
"Calls the unhandled error handler if there is no previous stage when request_id matches"
)
def test_matching_request_id_with_no_previous_stage(self, stage, op, iot_response):
def test_matching_request_id_with_no_previous_stage(
self, stage, op, iot_response, unhandled_error_handler
):
stage.next.previous = None
operation_flow.pass_event_to_previous_stage(stage.next, iot_response)
assert stage.pipeline_root.unhandled_error_handler.call_count == 1
assert unhandled_error_handler.call_count == 1
@pytest.mark.it(
"Does nothing if an IotResponse with an identical request_id is received a second time"
)
def test_ignores_duplicate_request_id(self, stage, op, iot_response):
def test_ignores_duplicate_request_id(self, stage, op, iot_response, unhandled_error_handler):
operation_flow.pass_event_to_previous_stage(stage.next, iot_response)
assert_callback_succeeded(op=op)
op.callback.reset_mock()
operation_flow.pass_event_to_previous_stage(stage.next, iot_response)
assert op.callback.call_count == 0
assert stage.pipeline_root.unhandled_error_handler.call_count == 0
assert unhandled_error_handler.call_count == 0
@pytest.mark.it(
"Does nothing if an IotResponse with a request_id is received for an operation that returned failure"
)
def test_ignores_request_id_from_failure(self, stage, op, mocker):
def test_ignores_request_id_from_failure(self, stage, op, mocker, unhandled_error_handler):
stage.next._run_op = mocker.MagicMock(side_effect=Exception)
stage.run_op(op)
@ -205,11 +281,11 @@ class TestCoordinateRequestAndResponseSendIotRequestHandleEvent(object):
op.callback.reset_mock()
operation_flow.pass_event_to_previous_stage(stage.next, resp)
assert op.callback.call_count == 0
assert stage.pipeline_root.unhandled_error_handler.call_count == 0
assert unhandled_error_handler.call_count == 0
@pytest.mark.it("Does nothing if an IotResponse with an unknown request_id is received")
def test_ignores_unknown_request_id(self, stage, op, iot_response):
def test_ignores_unknown_request_id(self, stage, op, iot_response, unhandled_error_handler):
iot_response.request_id = fake_request_id
operation_flow.pass_event_to_previous_stage(stage.next, iot_response)
assert op.callback.call_count == 0
assert stage.pipeline_root.unhandled_error_handler.call_count == 0
assert unhandled_error_handler.call_count == 0

Просмотреть файл

@ -25,6 +25,17 @@ from tests.common.pipeline import pipeline_stage_test
logging.basicConfig(level=logging.INFO)
# This fixture makes it look like all test in this file tests are running
# inside the pipeline thread. Because this is an autouse fixture, we
# manually add it to the individual test.py files that need it. If,
# instead, we had added it to some conftest.py, it would be applied to
# every tests in every file and we don't want that.
@pytest.fixture(autouse=True)
def apply_fake_pipeline_thread(fake_pipeline_thread):
pass
this_module = sys.modules[__name__]
fake_client_id = "__fake_client_id__"
@ -57,6 +68,7 @@ pipeline_stage_test.add_base_pipeline_stage_tests(
handled_ops=ops_handled_by_this_stage,
all_events=all_common_events,
handled_events=events_handled_by_this_stage,
methods_that_enter_pipeline_thread=["_on_message_received", "on_connected", "on_disconnected"],
)
@ -150,7 +162,7 @@ def op_set_sas_token(callback):
@pytest.fixture
def op_set_client_certificate(callback):
return pipeline_ops_base.SetClientAuthenticationCertificateOperation(
certificate=fake_certificate
certificate=fake_certificate, callback=callback
)

Просмотреть файл

@ -4,4 +4,12 @@
# license information.
# --------------------------------------------------------------------------
from tests.common.pipeline.fixtures import callback, fake_exception, fake_base_exception, event
from tests.common.pipeline.fixtures import (
callback,
fake_exception,
fake_base_exception,
event,
fake_pipeline_thread,
fake_non_pipeline_thread,
unhandled_error_handler,
)

Просмотреть файл

@ -29,6 +29,16 @@ logging.basicConfig(level=logging.INFO)
this_module = sys.modules[__name__]
# This fixture makes it look like all test in this file tests are running
# inside the pipeline thread. Because this is an autouse fixture, we
# manually add it to the individual test.py files that need it. If,
# instead, we had added it to some conftest.py, it would be applied to
# every tests in every file and we don't want that.
@pytest.fixture(autouse=True)
def apply_fake_pipeline_thread(fake_pipeline_thread):
pass
fake_device_id = "__fake_device_id__"
fake_module_id = "__fake_module_id__"
fake_hostname = "__fake_hostname__"

Просмотреть файл

@ -39,6 +39,17 @@ logging.basicConfig(level=logging.INFO)
this_module = sys.modules[__name__]
# This fixture makes it look like all test in this file tests are running
# inside the pipeline thread. Because this is an autouse fixture, we
# manually add it to the individual test.py files that need it. If,
# instead, we had added it to some conftest.py, it would be applied to
# every tests in every file and we don't want that.
@pytest.fixture(autouse=True)
def apply_fake_pipeline_thread(fake_pipeline_thread):
pass
fake_device_id = "__fake_device_id__"
fake_module_id = "__fake_module_id__"
fake_hostname = "__fake_hostname__"
@ -509,7 +520,6 @@ class TestIoTHubMQTTConverterWithEnableFeature(object):
def add_pipeline_root(stage, mocker):
root = pipeline_stages_base.PipelineRootStage()
mocker.spy(root, "handle_pipeline_event")
mocker.spy(root, "unhandled_error_handler")
stage.previous = root
stage.pipeline_root = root
@ -895,46 +905,61 @@ class TestIotHubMQTTConverterHandlePipelineEventTwinResponse(object):
assert new_event.response_body == fake_payload
@pytest.mark.it("Calls the unhandled exception handler if there is no previous stage")
def test_no_previous_stage(self, stage, fixup_stage_for_test, fake_event):
def test_no_previous_stage(
self, stage, fixup_stage_for_test, fake_event, unhandled_error_handler
):
stage.previous = None
stage.handle_pipeline_event(fake_event)
assert stage.pipeline_root.unhandled_error_handler.call_count == 1
assert isinstance(
stage.pipeline_root.unhandled_error_handler.call_args[0][0], NotImplementedError
)
assert unhandled_error_handler.call_count == 1
assert isinstance(unhandled_error_handler.call_args[0][0], NotImplementedError)
@pytest.mark.it(
"Calls the unhandled exception handler if the requet_id is missing from the topic name"
)
def test_invalid_topic_with_missing_request_id(
self, stage, fixup_stage_for_test, fake_event, fake_topic_name_with_missing_request_id
self,
stage,
fixup_stage_for_test,
fake_event,
fake_topic_name_with_missing_request_id,
unhandled_error_handler,
):
fake_event.topic = fake_topic_name_with_missing_request_id
stage.handle_pipeline_event(event=fake_event)
assert stage.pipeline_root.unhandled_error_handler.call_count == 1
assert isinstance(stage.pipeline_root.unhandled_error_handler.call_args[0][0], IndexError)
assert unhandled_error_handler.call_count == 1
assert isinstance(unhandled_error_handler.call_args[0][0], IndexError)
@pytest.mark.it(
"Calls the unhandled exception handler if the status code is missing from the topic name"
)
def test_invlid_topic_with_missing_status_code(
self, stage, fixup_stage_for_test, fake_event, fake_topic_name_with_missing_status_code
self,
stage,
fixup_stage_for_test,
fake_event,
fake_topic_name_with_missing_status_code,
unhandled_error_handler,
):
fake_event.topic = fake_topic_name_with_missing_status_code
stage.handle_pipeline_event(event=fake_event)
assert stage.pipeline_root.unhandled_error_handler.call_count == 1
assert isinstance(stage.pipeline_root.unhandled_error_handler.call_args[0][0], ValueError)
assert unhandled_error_handler.call_count == 1
assert isinstance(unhandled_error_handler.call_args[0][0], ValueError)
@pytest.mark.it(
"Calls the unhandled exception handler if the status code in the topic name is not numeric"
)
def test_invlid_topic_with_bad_status_code(
self, stage, fixup_stage_for_test, fake_event, fake_topic_name_with_bad_status_code
self,
stage,
fixup_stage_for_test,
fake_event,
fake_topic_name_with_bad_status_code,
unhandled_error_handler,
):
fake_event.topic = fake_topic_name_with_bad_status_code
stage.handle_pipeline_event(event=fake_event)
assert stage.pipeline_root.unhandled_error_handler.call_count == 1
assert isinstance(stage.pipeline_root.unhandled_error_handler.call_args[0][0], ValueError)
assert unhandled_error_handler.call_count == 1
assert isinstance(unhandled_error_handler.call_args[0][0], ValueError)
@pytest.mark.describe(
@ -984,30 +1009,34 @@ class TestIotHubMQTTConverterHandlePipelineEventTwinPatch(object):
assert new_event.patch == fake_patch
@pytest.mark.it("Calls the unhandled exception handler if there is no previous stage")
def test_no_previous_stage(self, stage, fixup_stage_for_test, fake_event):
def test_no_previous_stage(
self, stage, fixup_stage_for_test, fake_event, unhandled_error_handler
):
stage.previous = None
stage.handle_pipeline_event(fake_event)
assert stage.pipeline_root.unhandled_error_handler.call_count == 1
assert isinstance(
stage.pipeline_root.unhandled_error_handler.call_args[0][0], NotImplementedError
)
assert unhandled_error_handler.call_count == 1
assert isinstance(unhandled_error_handler.call_args[0][0], NotImplementedError)
@pytest.mark.it("Calls the unhandled exception handler if the payload is not a Bytes object")
def test_payload_not_bytes(self, stage, fixup_stage_for_test, fake_event, fake_patch_not_bytes):
def test_payload_not_bytes(
self, stage, fixup_stage_for_test, fake_event, fake_patch_not_bytes, unhandled_error_handler
):
fake_event.payload = fake_patch_not_bytes
stage.handle_pipeline_event(fake_event)
assert stage.pipeline_root.unhandled_error_handler.call_count == 1
assert unhandled_error_handler.call_count == 1
if not (
isinstance(stage.pipeline_root.unhandled_error_handler.call_args[0][0], AttributeError)
or isinstance(stage.pipeline_root.unhandled_error_handler.call_args[0][0], ValueError)
isinstance(unhandled_error_handler.call_args[0][0], AttributeError)
or isinstance(unhandled_error_handler.call_args[0][0], ValueError)
):
assert False
@pytest.mark.it(
"Calls the unhandled exception handler if the payload cannot be deserialized as a JSON object"
)
def test_payload_not_json(self, stage, fixup_stage_for_test, fake_event, fake_patch_not_json):
def test_payload_not_json(
self, stage, fixup_stage_for_test, fake_event, fake_patch_not_json, unhandled_error_handler
):
fake_event.payload = fake_patch_not_json
stage.handle_pipeline_event(fake_event)
assert stage.pipeline_root.unhandled_error_handler.call_count == 1
assert isinstance(stage.pipeline_root.unhandled_error_handler.call_args[0][0], ValueError)
assert unhandled_error_handler.call_count == 1
assert isinstance(unhandled_error_handler.call_args[0][0], ValueError)

Просмотреть файл

@ -4,4 +4,12 @@
# license information.
# --------------------------------------------------------------------------
from tests.common.pipeline.fixtures import callback, fake_exception, fake_base_exception, event
from tests.common.pipeline.fixtures import (
callback,
fake_exception,
fake_base_exception,
event,
fake_pipeline_thread,
fake_non_pipeline_thread,
unhandled_error_handler,
)

Просмотреть файл

@ -33,6 +33,13 @@ logging.basicConfig(level=logging.INFO)
this_module = sys.modules[__name__]
# Make it look like we're always running inside pipeline threads
@pytest.fixture(autouse=True)
def apply_fake_pipeline_thread(fake_pipeline_thread):
pass
fake_device_id = "elder_wand"
fake_registration_id = "registered_remembrall"
fake_provisioning_host = "hogwarts.com"

Просмотреть файл

@ -34,6 +34,17 @@ logging.basicConfig(level=logging.INFO)
this_module = sys.modules[__name__]
# This fixture makes it look like all test in this file tests are running
# inside the pipeline thread. Because this is an autouse fixture, we
# manually add it to the individual test.py files that need it. If,
# instead, we had added it to some conftest.py, it would be applied to
# every tests in every file and we don't want that.
@pytest.fixture(autouse=True)
def apply_fake_pipeline_thread(fake_pipeline_thread):
pass
fake_device_id = "elder_wand"
fake_registration_id = "registered_remembrall"
fake_provisioning_host = "hogwarts.com"

Просмотреть файл

@ -10,6 +10,7 @@ from azure.iot.device.common.models import X509
from azure.iot.device.provisioning.security.sk_security_client import SymmetricKeySecurityClient
from azure.iot.device.provisioning.security.x509_security_client import X509SecurityClient
from azure.iot.device.provisioning.pipeline.provisioning_pipeline import ProvisioningPipeline
from tests.common.pipeline import helpers
send_msg_qos = 1
@ -110,6 +111,10 @@ def mock_provisioning_pipeline(params_security_clients):
provisioning_pipeline.on_connected = MagicMock()
provisioning_pipeline.on_disconnected = MagicMock()
provisioning_pipeline.on_message_received = MagicMock()
helpers.add_mock_method_waiter(provisioning_pipeline, "on_connected")
helpers.add_mock_method_waiter(provisioning_pipeline, "on_disconnected")
helpers.add_mock_method_waiter(provisioning_pipeline._pipeline.transport, "publish")
yield provisioning_pipeline
provisioning_pipeline.disconnect()
@ -148,6 +153,7 @@ class TestConnect(object):
assert_for_client_x509(mock_mqtt_transport.connect.call_args[1]["client_certificate"])
mock_mqtt_transport.on_mqtt_connected()
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
@pytest.mark.it("After complete calls handler with new state")
def test_connected_state_handler_called_wth_new_state_once_provider_gets_connected(
@ -157,6 +163,7 @@ class TestConnect(object):
mock_provisioning_pipeline.connect()
mock_mqtt_transport.on_mqtt_connected()
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
mock_provisioning_pipeline.on_connected.assert_called_once_with("connected")
@ -169,6 +176,7 @@ class TestConnect(object):
mock_provisioning_pipeline.connect()
mock_provisioning_pipeline.connect()
mock_mqtt_transport.on_mqtt_connected()
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
assert mock_mqtt_transport.connect.call_count == 1
@ -187,6 +195,7 @@ class TestConnect(object):
mock_provisioning_pipeline.connect()
mock_mqtt_transport.on_mqtt_connected()
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
mock_mqtt_transport.reset_mock()
mock_provisioning_pipeline.on_connected.reset_mock()
@ -197,11 +206,13 @@ class TestConnect(object):
mock_provisioning_pipeline.connect()
mock_mqtt_transport.connect.assert_not_called()
mock_provisioning_pipeline.wait_for_on_connected_to_not_be_called()
mock_provisioning_pipeline.on_connected.assert_not_called()
mock_mqtt_transport.on_mqtt_published(0)
mock_mqtt_transport.connect.assert_not_called()
mock_provisioning_pipeline.wait_for_on_connected_to_not_be_called()
mock_provisioning_pipeline.on_connected.assert_not_called()
@ -216,6 +227,7 @@ class TestSendRegister(object):
mock_provisioning_pipeline.connect()
mock_mqtt_transport.on_mqtt_connected()
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
mock_provisioning_pipeline.send_request(
request_id=fake_request_id, request_payload=fake_mqtt_payload
)
@ -233,6 +245,7 @@ class TestSendRegister(object):
fake_request_id
)
mock_mqtt_transport.wait_for_publish_to_be_called()
assert mock_mqtt_transport.publish.call_count == 1
assert mock_mqtt_transport.publish.call_args[1]["topic"] == fake_publish_topic
assert mock_mqtt_transport.publish.call_args[1]["payload"] == fake_mqtt_payload
@ -259,11 +272,14 @@ class TestSendRegister(object):
assert_for_client_x509(mock_mqtt_transport.connect.call_args[1]["client_certificate"])
# verify that we're not connected yet and verify that we havent't published yet
mock_provisioning_pipeline.wait_for_on_connected_to_not_be_called()
mock_provisioning_pipeline.on_connected.assert_not_called()
mock_mqtt_transport.wait_for_publish_to_not_be_called()
mock_mqtt_transport.publish.assert_not_called()
# finish the connection
mock_mqtt_transport.on_mqtt_connected()
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
# verify that our connected callback was called and verify that we published the event
mock_provisioning_pipeline.on_connected.assert_called_once_with("connected")
@ -272,6 +288,7 @@ class TestSendRegister(object):
fake_request_id
)
mock_mqtt_transport.wait_for_publish_to_be_called()
assert mock_mqtt_transport.publish.call_count == 1
assert mock_mqtt_transport.publish.call_args[1]["topic"] == fake_publish_topic
assert mock_mqtt_transport.publish.call_args[1]["payload"] == fake_mqtt_payload
@ -299,17 +316,21 @@ class TestSendRegister(object):
)
# verify that we're not connected yet and verify that we havent't published yet
mock_provisioning_pipeline.wait_for_on_connected_to_not_be_called()
mock_provisioning_pipeline.on_connected.assert_not_called()
mock_mqtt_transport.wait_for_publish_to_not_be_called()
mock_mqtt_transport.publish.assert_not_called()
# finish the connection
mock_mqtt_transport.on_mqtt_connected()
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
# verify that our connected callback was called and verify that we published the event
mock_provisioning_pipeline.on_connected.assert_called_once_with("connected")
fake_publish_topic = "$dps/registrations/PUT/iotdps-register/?$rid={}".format(
fake_request_id
)
mock_mqtt_transport.wait_for_publish_to_be_called()
assert mock_mqtt_transport.publish.call_count == 1
assert mock_mqtt_transport.publish.call_args[1]["topic"] == fake_publish_topic
assert mock_mqtt_transport.publish.call_args[1]["payload"] == fake_mqtt_payload
@ -326,6 +347,7 @@ class TestSendRegister(object):
# connect
mock_provisioning_pipeline.connect()
mock_mqtt_transport.on_mqtt_connected()
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
# send an event
callback_1 = MagicMock()
@ -336,6 +358,7 @@ class TestSendRegister(object):
fake_publish_topic = "$dps/registrations/PUT/iotdps-register/?$rid={}".format(
fake_request_id_1
)
mock_mqtt_transport.wait_for_publish_to_be_called()
assert mock_mqtt_transport.publish.call_count == 1
assert mock_mqtt_transport.publish.call_args[1]["topic"] == fake_publish_topic
assert mock_mqtt_transport.publish.call_args[1]["payload"] == fake_msg_1
@ -349,6 +372,7 @@ class TestSendRegister(object):
# verify that we've called publish twice and verify that neither send_d2c_message
# has completed (because we didn't do anything here to complete it).
mock_mqtt_transport.wait_for_publish_to_be_called()
assert mock_mqtt_transport.publish.call_count == 2
callback_1.assert_not_called()
callback_2.assert_not_called()
@ -360,6 +384,7 @@ class TestSendRegister(object):
# connect
mock_provisioning_pipeline.connect()
mock_mqtt_transport.on_mqtt_connected()
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
# send an event
mock_provisioning_pipeline.send_request(
@ -383,6 +408,7 @@ class TestSendQuery(object):
mock_provisioning_pipeline.connect()
mock_mqtt_transport.on_mqtt_connected()
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
mock_provisioning_pipeline.send_request(
request_id=fake_request_id,
request_payload=fake_mqtt_payload,
@ -402,6 +428,7 @@ class TestSendQuery(object):
fake_request_id, fake_operation_id
)
mock_mqtt_transport.wait_for_publish_to_be_called()
assert mock_mqtt_transport.publish.call_count == 1
assert mock_mqtt_transport.publish.call_args[1]["topic"] == fake_publish_topic
assert mock_mqtt_transport.publish.call_args[1]["payload"] == fake_mqtt_payload
@ -416,6 +443,7 @@ class TestDisconnect(object):
mock_provisioning_pipeline.connect()
mock_mqtt_transport.on_mqtt_connected()
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
mock_provisioning_pipeline.disconnect()
mock_mqtt_transport.disconnect.assert_called_once_with()
@ -434,10 +462,12 @@ class TestDisconnect(object):
mock_provisioning_pipeline.connect()
mock_mqtt_transport.on_mqtt_connected()
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
mock_provisioning_pipeline.disconnect()
mock_mqtt_transport.on_mqtt_disconnected()
mock_provisioning_pipeline.wait_for_on_disconnected_to_be_called()
mock_provisioning_pipeline.on_disconnected.assert_called_once_with("disconnected")
@ -450,6 +480,7 @@ class TestEnable(object):
mock_provisioning_pipeline.connect()
mock_mqtt_transport.on_mqtt_connected()
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
mock_provisioning_pipeline.enable_responses()
assert mock_mqtt_transport.subscribe.call_count == 1
@ -465,6 +496,7 @@ class TestDisable(object):
mock_provisioning_pipeline.connect()
mock_mqtt_transport.on_mqtt_connected()
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
mock_provisioning_pipeline.disable_responses(None)
assert mock_mqtt_transport.unsubscribe.call_count == 1

Просмотреть файл

@ -50,6 +50,7 @@ class TestClientCreate(object):
patch_set_sym_client = mocker.patch.object(
pipeline_ops_provisioning, "SetSymmetricKeySecurityClientOperation"
)
patch_set_sym_client.callback = mocker.MagicMock()
client = ProvisioningDeviceClient.create_from_symmetric_key(
fake_provisioning_host, fake_symmetric_key, fake_registration_id, fake_id_scope
)

Просмотреть файл

@ -59,7 +59,7 @@ jobs:
inputs:
testResultsFiles: '**/*-test-results.xml'
testRunTitle: 'Python $(python.version)'
condition: always()
condition: always()
- task: PublishCodeCoverageResults@1
inputs: