marshal all calls into a single pipeline thread (#133)
* marshal all calls into a single pipeline thread * incorporate review feedback * code review feedback and tests * flake8 * PR feedback * Automatic LF/CRLF handling regardless of user's .gitconfig settings
This commit is contained in:
Родитель
cb0c5f7e1e
Коммит
1206459730
|
@ -0,0 +1,25 @@
|
|||
# Default behavior: if Git thinks a file is text (as opposed to binary), it
|
||||
# will normalize line endings to LF in the repository, but convert to your
|
||||
# platform's native line endings on checkout (e.g., CRLF for Windows).
|
||||
* text=auto
|
||||
|
||||
# Explicitly declare text files you want to always be normalized and converted
|
||||
# to native line endings on checkout. E.g.,
|
||||
*.md text=auto
|
||||
*.json text=auto
|
||||
*.ps1 text=auto
|
||||
|
||||
# Declare files that will always have LF line endings on checkout. E.g.,
|
||||
*.sh text eol=lf
|
||||
|
||||
# Denote all files that should not have line endings normalized, should not be
|
||||
# merged, and should not show in a textual diff.
|
||||
*.docm binary
|
||||
*.docx binary
|
||||
*.ico binary
|
||||
*.lib binary
|
||||
*.png binary
|
||||
*.pptx binary
|
||||
*.snk binary
|
||||
*.vsdx binary
|
||||
*.xps binary
|
|
@ -56,6 +56,7 @@ class MQTTTransport(object):
|
|||
self._mqtt_client = mqtt.Client(
|
||||
client_id=self._client_id, clean_session=False, protocol=mqtt.MQTTv311
|
||||
)
|
||||
self._mqtt_client.enable_logger(logging.getLogger("paho"))
|
||||
|
||||
def on_connect(client, userdata, flags, rc):
|
||||
logger.info("connected with result code: {}".format(rc))
|
||||
|
|
|
@ -6,12 +6,15 @@
|
|||
|
||||
import logging
|
||||
import sys
|
||||
from . import pipeline_thread
|
||||
from azure.iot.device.common import unhandled_exceptions
|
||||
|
||||
from six.moves import queue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def run_ops_in_serial(stage, *args, **kwargs):
|
||||
"""
|
||||
Run the operations passed in *args in a serial manner, such that each operation waits for the
|
||||
|
@ -57,9 +60,11 @@ def run_ops_in_serial(stage, *args, **kwargs):
|
|||
if not callback:
|
||||
raise TypeError("callback is required")
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def on_last_op_done(last_op):
|
||||
if finally_op:
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def on_finally_done(finally_op):
|
||||
logger.info(
|
||||
"{}({}):run_ops_serial: finally_op done.".format(stage.name, finally_op.name)
|
||||
|
@ -85,7 +90,7 @@ def run_ops_in_serial(stage, *args, **kwargs):
|
|||
),
|
||||
exc_info=e,
|
||||
)
|
||||
stage.pipeline_root.unhandled_error_handler(e)
|
||||
unhandled_exceptions.exception_caught_in_background_thread(e)
|
||||
|
||||
finally_op.callback = on_finally_done
|
||||
logger.info(
|
||||
|
@ -107,8 +112,9 @@ def run_ops_in_serial(stage, *args, **kwargs):
|
|||
),
|
||||
exc_info=e,
|
||||
)
|
||||
stage.pipeline_root.unhandled_error_handler(e)
|
||||
unhandled_exceptions.exception_caught_in_background_thread(e)
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def on_op_done(completed_op):
|
||||
logger.info(
|
||||
"{}({}):run_ops_serial: completed. {} items left".format(
|
||||
|
@ -146,6 +152,7 @@ def run_ops_in_serial(stage, *args, **kwargs):
|
|||
pass_op_to_next_stage(stage, first_op)
|
||||
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def delegate_to_different_op(stage, original_op, new_op):
|
||||
"""
|
||||
Continue an operation using a new operation. This means that the new operation
|
||||
|
@ -182,6 +189,7 @@ def delegate_to_different_op(stage, original_op, new_op):
|
|||
|
||||
logger.info("{}({}): continuing with {} op".format(stage.name, original_op.name, new_op.name))
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def new_op_complete(op):
|
||||
logger.info(
|
||||
"{}({}): completing with result from {}".format(
|
||||
|
@ -195,6 +203,7 @@ def delegate_to_different_op(stage, original_op, new_op):
|
|||
pass_op_to_next_stage(stage, new_op)
|
||||
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def pass_op_to_next_stage(stage, op):
|
||||
"""
|
||||
Helper function to continue a given operation by passing it to the next stage
|
||||
|
@ -220,6 +229,7 @@ def pass_op_to_next_stage(stage, op):
|
|||
stage.next.run_op(op)
|
||||
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def complete_op(stage, op):
|
||||
"""
|
||||
Helper function to complete an operation by calling its callback function thus
|
||||
|
@ -240,9 +250,10 @@ def complete_op(stage, op):
|
|||
),
|
||||
exc_info=e,
|
||||
)
|
||||
stage.pipeline_root.unhandled_error_handler(e)
|
||||
unhandled_exceptions.exception_caught_in_background_thread(e)
|
||||
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def pass_event_to_previous_stage(stage, event):
|
||||
"""
|
||||
Helper function to pass an event to the previous stage of the pipeline. This is the default
|
||||
|
@ -259,4 +270,4 @@ def pass_event_to_previous_stage(stage, event):
|
|||
error = NotImplementedError(
|
||||
"{} unhandled at {} stage with no previous stage".format(event.name, stage.name)
|
||||
)
|
||||
stage.pipeline_root.unhandled_error_handler(error)
|
||||
unhandled_exceptions.exception_caught_in_background_thread(error)
|
||||
|
|
|
@ -12,6 +12,8 @@ from six.moves import queue
|
|||
from . import pipeline_events_base
|
||||
from . import pipeline_ops_base
|
||||
from . import operation_flow
|
||||
from . import pipeline_thread
|
||||
from azure.iot.device.common import unhandled_exceptions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -75,6 +77,7 @@ class PipelineStage(object):
|
|||
self.previous = None
|
||||
self.pipeline_root = None
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def run_op(self, op):
|
||||
"""
|
||||
Run the given operation. This is the public function that outside callers would call to run an
|
||||
|
@ -111,6 +114,7 @@ class PipelineStage(object):
|
|||
"""
|
||||
pass
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def handle_pipeline_event(self, event):
|
||||
"""
|
||||
Handle a pipeline event that arrives from the stage below this stage. Derived
|
||||
|
@ -123,11 +127,9 @@ class PipelineStage(object):
|
|||
try:
|
||||
self._handle_pipeline_event(event)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
msg="Error in %s._handle_pipeline_event call".format(self.name), exc_info=e
|
||||
)
|
||||
self.pipeline_root.unhandled_error_handler(e)
|
||||
unhandled_exceptions.exception_caught_in_background_thread(e)
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def _handle_pipeline_event(self, event):
|
||||
"""
|
||||
Handle a pipeline event that arrives from the stage below this stage. This
|
||||
|
@ -138,6 +140,7 @@ class PipelineStage(object):
|
|||
"""
|
||||
operation_flow.pass_event_to_previous_stage(self, event)
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def on_connected(self):
|
||||
"""
|
||||
Called by lower layers when the protocol client connects
|
||||
|
@ -145,6 +148,7 @@ class PipelineStage(object):
|
|||
if self.previous:
|
||||
self.previous.on_connected()
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def on_disconnected(self):
|
||||
"""
|
||||
Called by lower layers when the protocol client disconnects
|
||||
|
@ -170,6 +174,11 @@ class PipelineRootStage(PipelineStage):
|
|||
super(PipelineRootStage, self).__init__()
|
||||
self.on_pipeline_event = None
|
||||
|
||||
def run_op(self, op):
|
||||
op.callback = pipeline_thread.invoke_on_callback_thread_nowait(op.callback)
|
||||
pipeline_thread.invoke_on_pipeline_thread(super(PipelineRootStage, self).run_op)(op)
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def _run_op(self, op):
|
||||
"""
|
||||
run the operation. At the root, the only thing to do is to pass the operation
|
||||
|
@ -196,16 +205,7 @@ class PipelineRootStage(PipelineStage):
|
|||
new_next_stage.pipeline_root = self
|
||||
return self
|
||||
|
||||
def unhandled_error_handler(self, error):
|
||||
"""
|
||||
Handler for errors that happen which cannot be tied to a specific operation.
|
||||
This is still a tentative implimentation and masy be replaced by
|
||||
some other mechanism as details on behavior are finalized.
|
||||
"""
|
||||
# TODO: decide how to pass this error to the app
|
||||
# TODO: if there's an error in the app handler, print it and exit
|
||||
pass
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def _handle_pipeline_event(self, event):
|
||||
"""
|
||||
Override of the PipelineEvent handler. Because this is the root of the pipeline,
|
||||
|
@ -241,6 +241,7 @@ class EnsureConnectionStage(PipelineStage):
|
|||
self.queue = queue.Queue()
|
||||
self.blocked = False
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def _run_op(self, op):
|
||||
# If this stage is currently blocked (because we're waiting for a connection
|
||||
# to complete, we queue up all operations until after the connect completes.
|
||||
|
@ -290,6 +291,7 @@ class EnsureConnectionStage(PipelineStage):
|
|||
else:
|
||||
operation_flow.pass_op_to_next_stage(self, op)
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def _block(self, op):
|
||||
"""
|
||||
block this stage while we're waiting for the connection to complete.
|
||||
|
@ -297,6 +299,7 @@ class EnsureConnectionStage(PipelineStage):
|
|||
logger.info("{}({}): enabling block".format(self.name, op.name))
|
||||
self.blocked = True
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def _unblock(self, op, error):
|
||||
"""
|
||||
Unblock this stage after the connection is complete. This also means
|
||||
|
@ -329,6 +332,7 @@ class EnsureConnectionStage(PipelineStage):
|
|||
)
|
||||
self.run_op(op_to_release)
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def _do_connect(self, op):
|
||||
"""
|
||||
Start connecting the protocol client in response to some operation (which may or may not be a Connect operation)
|
||||
|
@ -343,6 +347,7 @@ class EnsureConnectionStage(PipelineStage):
|
|||
self.queue.put_nowait(op)
|
||||
|
||||
# function that gets called after we're connected.
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def on_connected(op_connect):
|
||||
logger.info("{}({}): connection is complete".format(self.name, op.name))
|
||||
# if we're connecting because some layer above us asked us to connect, we complete that operation
|
||||
|
@ -361,10 +366,12 @@ class EnsureConnectionStage(PipelineStage):
|
|||
self, pipeline_ops_base.ConnectOperation(callback=on_connected)
|
||||
)
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def on_connected(self):
|
||||
self.connected = True
|
||||
PipelineStage.on_connected(self)
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def on_disconnected(self):
|
||||
self.connected = False
|
||||
PipelineStage.on_disconnected(self)
|
||||
|
@ -381,6 +388,7 @@ class CoordinateRequestAndResponseStage(PipelineStage):
|
|||
super(CoordinateRequestAndResponseStage, self).__init__()
|
||||
self.pending_responses = {}
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def _run_op(self, op):
|
||||
if isinstance(op, pipeline_ops_base.SendIotRequestAndWaitForResponseOperation):
|
||||
# Convert SendIotRequestAndWaitForResponseOperation operation into a SendIotRequestOperation operation
|
||||
|
@ -390,6 +398,7 @@ class CoordinateRequestAndResponseStage(PipelineStage):
|
|||
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def on_send_request_done(send_request_op):
|
||||
logger.info(
|
||||
"{}({}): Finished sending {} request to {} resource {}".format(
|
||||
|
@ -433,6 +442,7 @@ class CoordinateRequestAndResponseStage(PipelineStage):
|
|||
else:
|
||||
operation_flow.pass_op_to_next_stage(self, op)
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def _handle_pipeline_event(self, event):
|
||||
if isinstance(event, pipeline_events_base.IotResponseEvent):
|
||||
# match IotResponseEvent events to the saved dictionary of SendIotRequestAndWaitForResponseOperation
|
||||
|
|
|
@ -11,6 +11,7 @@ from . import (
|
|||
pipeline_ops_mqtt,
|
||||
pipeline_events_mqtt,
|
||||
operation_flow,
|
||||
pipeline_thread,
|
||||
)
|
||||
from azure.iot.device.common.mqtt_transport import MQTTTransport
|
||||
|
||||
|
@ -24,6 +25,7 @@ class MQTTClientStage(PipelineStage):
|
|||
is not in the MQTT group of operations, but can only be run at the protocol level.
|
||||
"""
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def _run_op(self, op):
|
||||
if isinstance(op, pipeline_ops_mqtt.SetMQTTConnectionArgsOperation):
|
||||
# pipeline_ops_mqtt.SetMQTTConnectionArgsOperation is where we create our MQTTTransport object and set
|
||||
|
@ -62,6 +64,7 @@ class MQTTClientStage(PipelineStage):
|
|||
elif isinstance(op, pipeline_ops_base.ConnectOperation):
|
||||
logger.info("{}({}): conneting".format(self.name, op.name))
|
||||
|
||||
@pipeline_thread.invoke_on_pipeline_thread_nowait
|
||||
def on_connected():
|
||||
logger.info("{}({}): on_connected. completing op.".format(self.name, op.name))
|
||||
self.transport.on_mqtt_connected = self.on_connected
|
||||
|
@ -106,6 +109,7 @@ class MQTTClientStage(PipelineStage):
|
|||
elif isinstance(op, pipeline_ops_base.ReconnectOperation):
|
||||
logger.info("{}({}): reconnecting".format(self.name, op.name))
|
||||
|
||||
@pipeline_thread.invoke_on_pipeline_thread_nowait
|
||||
def on_connected():
|
||||
logger.info("{}({}): on_connected. completing op.".format(self.name, op.name))
|
||||
self.transport.on_mqtt_connected = self.on_connected
|
||||
|
@ -123,6 +127,7 @@ class MQTTClientStage(PipelineStage):
|
|||
elif isinstance(op, pipeline_ops_base.DisconnectOperation):
|
||||
logger.info("{}({}): disconnecting".format(self.name, op.name))
|
||||
|
||||
@pipeline_thread.invoke_on_pipeline_thread_nowait
|
||||
def on_disconnected():
|
||||
logger.info("{}({}): on_disconnected. completing op.".format(self.name, op.name))
|
||||
self.transport.on_mqtt_disconnected = self.on_disconnected
|
||||
|
@ -140,6 +145,7 @@ class MQTTClientStage(PipelineStage):
|
|||
elif isinstance(op, pipeline_ops_mqtt.MQTTPublishOperation):
|
||||
logger.info("{}({}): publishing on {}".format(self.name, op.name, op.topic))
|
||||
|
||||
@pipeline_thread.invoke_on_pipeline_thread_nowait
|
||||
def on_published():
|
||||
logger.info("{}({}): PUBACK received. completing op.".format(self.name, op.name))
|
||||
operation_flow.complete_op(self, op)
|
||||
|
@ -149,6 +155,7 @@ class MQTTClientStage(PipelineStage):
|
|||
elif isinstance(op, pipeline_ops_mqtt.MQTTSubscribeOperation):
|
||||
logger.info("{}({}): subscribing to {}".format(self.name, op.name, op.topic))
|
||||
|
||||
@pipeline_thread.invoke_on_pipeline_thread_nowait
|
||||
def on_subscribed():
|
||||
logger.info("{}({}): SUBACK received. completing op.".format(self.name, op.name))
|
||||
operation_flow.complete_op(self, op)
|
||||
|
@ -158,6 +165,7 @@ class MQTTClientStage(PipelineStage):
|
|||
elif isinstance(op, pipeline_ops_mqtt.MQTTUnsubscribeOperation):
|
||||
logger.info("{}({}): unsubscribing from {}".format(self.name, op.name, op.topic))
|
||||
|
||||
@pipeline_thread.invoke_on_pipeline_thread_nowait
|
||||
def on_unsubscribed():
|
||||
logger.info("{}({}): UNSUBACK received. completing op.".format(self.name, op.name))
|
||||
operation_flow.complete_op(self, op)
|
||||
|
@ -167,6 +175,7 @@ class MQTTClientStage(PipelineStage):
|
|||
else:
|
||||
operation_flow.pass_op_to_next_stage(self, op)
|
||||
|
||||
@pipeline_thread.invoke_on_pipeline_thread_nowait
|
||||
def _on_message_received(self, topic, payload):
|
||||
"""
|
||||
Handler that gets called by the protocol library when an incoming message arrives.
|
||||
|
@ -176,3 +185,11 @@ class MQTTClientStage(PipelineStage):
|
|||
stage=self,
|
||||
event=pipeline_events_mqtt.IncomingMQTTMessageEvent(topic=topic, payload=payload),
|
||||
)
|
||||
|
||||
@pipeline_thread.invoke_on_pipeline_thread_nowait
|
||||
def on_connected(self):
|
||||
super(MQTTClientStage, self).on_connected()
|
||||
|
||||
@pipeline_thread.invoke_on_pipeline_thread_nowait
|
||||
def on_disconnected(self):
|
||||
super(MQTTClientStage, self).on_disconnected()
|
||||
|
|
|
@ -0,0 +1,196 @@
|
|||
# --------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
import functools
|
||||
import logging
|
||||
import threading
|
||||
import traceback
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from azure.iot.device.common import unhandled_exceptions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
"""
|
||||
This module contains decorators that are used to marshal code into pipeline and
|
||||
callback threads and to assert that code is being called in the correct thread.
|
||||
|
||||
The intention of these decorators is to ensure the following:
|
||||
|
||||
1. All pipeline functions execute in a single thread, known as the "pipeline
|
||||
thread". The `invoke_on_pipeline_thread` and `invoke_on_pipeline_thread_nowait`
|
||||
decorators cause the decorated function to run on the pipeline thread.
|
||||
|
||||
2. If the pipeline thread is busy running a different function, the invoke
|
||||
decorators will wait until that function is complete before invoking another
|
||||
function on that thread.
|
||||
|
||||
3. There is a different thread which is used for callbacks into user code, known
|
||||
as the the "callback thread". This is not meant for callbacks into pipeline
|
||||
code. Those callbacks should still execute on the pipeline thread. The
|
||||
`invoke_on_callback_thread_nowait` decorator is used to ensure that callbacks
|
||||
execute on the callback thread.
|
||||
|
||||
4. Decorators which cause thread switches are used only when necessary. The
|
||||
pipeline thread is only entered in places where we know that external code is
|
||||
calling into the pipeline (such as a client API call or a callback from a
|
||||
third-party library). Likewise, the callback thread is only entered in places
|
||||
where we know that the pipeline is calling back into client code.
|
||||
|
||||
5. Exceptions raised from the pipeline thread are still able to be caught by
|
||||
the function which entered the pipeline thread.
|
||||
|
||||
5. Calls into the pipeline thread can either block or not block. Blocking is used
|
||||
for cases where the caller needs a return value from the pipeline or is
|
||||
expecting to handle any errors raised from the pipeline thread. Blocking is
|
||||
not used when the code calling into the pipeline is not waiting for a response
|
||||
and is not expecting to handle any exceptions, such as protocol library
|
||||
handlers which call into the pipeline to deliver protocol messages.
|
||||
|
||||
6. Calls into the callback thread could theoretically block, but we currently
|
||||
only have decorators which enter the callback thread without blocking. This
|
||||
is done to ensure that client code does not execute on the pipeline thread and
|
||||
also to ensure that the pipline thread is not blocked while waiting for client
|
||||
code to execute.
|
||||
|
||||
These decorators use concurrent.futures.Future and the ThreadPoolExecutor because:
|
||||
|
||||
1. The thread pooling with a pool size of 1 gives us a single thread to run all
|
||||
pipeline operations and a different (single) thread to run all callbacks. If
|
||||
the code attempts to run a second pipeline operation (or callback) while a
|
||||
different one is running, the ThreadPoolExecutor will queue the code until the
|
||||
first call is completed.
|
||||
|
||||
2. The concurent.futures.Future object properly handles both Exception and
|
||||
BaseException errors, re-raising them when the Future.result method is called.
|
||||
threading.Thread.get() was not an option because it doesn't re-raise
|
||||
BaseException errors when Thread.get is called.
|
||||
|
||||
3. concurrent.futures is available as a backport to 2.7.
|
||||
|
||||
"""
|
||||
|
||||
_executors = {}
|
||||
|
||||
|
||||
def _get_named_executor(thread_name):
|
||||
"""
|
||||
Get a ThreadPoolExecutor object with the given name. If no such executor exists,
|
||||
this function will create on with a single worker and assign it to the provided
|
||||
name.
|
||||
"""
|
||||
global _executors
|
||||
if thread_name not in _executors:
|
||||
logger.info("Creating {} executor".format(thread_name))
|
||||
_executors[thread_name] = ThreadPoolExecutor(max_workers=1)
|
||||
return _executors[thread_name]
|
||||
|
||||
|
||||
def _invoke_on_executor_thread(func, thread_name, block=True):
|
||||
"""
|
||||
Return wrapper to run the function on a given thread. If block==False,
|
||||
the call returns immediately without waiting for the decorated function to complete.
|
||||
If block==True, the call waits for the decorated function to complete before returning.
|
||||
"""
|
||||
|
||||
# Mocks on py27 don't have a __name__ attribute. Use str() if you can't use __name__
|
||||
try:
|
||||
function_name = func.__name__
|
||||
function_has_name = True
|
||||
except AttributeError:
|
||||
function_name = str(func)
|
||||
function_has_name = False
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
if threading.current_thread().name is not thread_name:
|
||||
logger.info("Starting {} in {} thread".format(function_name, thread_name))
|
||||
|
||||
def thread_proc():
|
||||
threading.current_thread().name = thread_name
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
if not block:
|
||||
unhandled_exceptions.exception_caught_in_background_thread(e)
|
||||
except BaseException:
|
||||
if not block:
|
||||
logger.error("Unhandled exception in background thread")
|
||||
logger.error(
|
||||
"This may cause the background thread to abort and may result in system instability."
|
||||
)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
# TODO: add a timeout here and throw exception on failure
|
||||
future = _get_named_executor(thread_name).submit(thread_proc)
|
||||
if block:
|
||||
return future.result()
|
||||
else:
|
||||
return future
|
||||
else:
|
||||
logger.debug("Already in {} thread for {}".format(thread_name, function_name))
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# Silly hack: On 2.7, we can't use @functools.wraps on callables don't have a __name__ attribute
|
||||
# attribute(like MagicMock object), so we only do it when we have a name. functools.update_wrapper
|
||||
# below is the same as using the @functools.wraps(func) decorator on the wrapper function above.
|
||||
if function_has_name:
|
||||
return functools.update_wrapper(wrapped=func, wrapper=wrapper)
|
||||
else:
|
||||
wrapper.__wrapped__ = func # needed by tests
|
||||
return wrapper
|
||||
|
||||
|
||||
def invoke_on_pipeline_thread(func):
|
||||
"""
|
||||
Run the decorated function on the pipeline thread.
|
||||
"""
|
||||
return _invoke_on_executor_thread(func=func, thread_name="pipeline")
|
||||
|
||||
|
||||
def invoke_on_pipeline_thread_nowait(func):
|
||||
"""
|
||||
Run the decorated function on the pipeline thread, but don't wait for it to complete
|
||||
"""
|
||||
return _invoke_on_executor_thread(func=func, thread_name="pipeline", block=False)
|
||||
|
||||
|
||||
def invoke_on_callback_thread_nowait(func):
|
||||
"""
|
||||
Run the decorated function on the callback thread, but don't wait for it to complete
|
||||
"""
|
||||
return _invoke_on_executor_thread(func=func, thread_name="callback", block=False)
|
||||
|
||||
|
||||
def _assert_executor_thread(func, thread_name):
|
||||
"""
|
||||
Decorator which asserts that the given function only gets called inside the given
|
||||
thread.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
|
||||
assert (
|
||||
threading.current_thread().name == thread_name
|
||||
), """
|
||||
Function {function_name} is not running inside {thread_name} thread.
|
||||
It should be. You should use invoke_on_{thread_name}_thread(_nowait) to enter the
|
||||
{thread_name} thread before calling this function. If you're hitting this from
|
||||
inside a test function, you may need to add the fake_pipeline_thread fixture to
|
||||
your test. (grep for apply_fake_pipeline_thread) """.format(
|
||||
function_name=func.__name__, thread_name=thread_name
|
||||
)
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def runs_on_pipeline_thread(func):
|
||||
"""
|
||||
Decorator which marks a function as only running inside the pipeline thread.
|
||||
"""
|
||||
return _assert_executor_thread(func=func, thread_name="pipeline")
|
|
@ -0,0 +1,27 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def exception_caught_in_background_thread(e):
|
||||
"""
|
||||
Function which handled exceptions that are caught in background thread. This is
|
||||
typically called from the callback thread inside the pipeline. These exceptions
|
||||
need special handling because callback functions are typically called inside a
|
||||
non-application thread in response to non-user-initiated actions, so there's
|
||||
nobody else to catch them.
|
||||
|
||||
This function gets called from inside an arbitrary thread context, so code that
|
||||
runs from this function should be limited to the bare minumum.
|
||||
|
||||
:param Error e: Exception object raised from inside a background thread
|
||||
"""
|
||||
|
||||
# @FUTURE: We should add a mechanism which allows applications to receive these
|
||||
# exceptions so they can respond accordingly
|
||||
logger.error(msg="Exception caught in background thread. Unable to handle.", exc_info=e)
|
|
@ -6,7 +6,12 @@
|
|||
|
||||
import json
|
||||
import logging
|
||||
from azure.iot.device.common.pipeline import pipeline_ops_base, PipelineStage, operation_flow
|
||||
from azure.iot.device.common.pipeline import (
|
||||
pipeline_ops_base,
|
||||
PipelineStage,
|
||||
operation_flow,
|
||||
pipeline_thread,
|
||||
)
|
||||
from . import pipeline_ops_iothub
|
||||
from . import constant
|
||||
|
||||
|
@ -27,6 +32,7 @@ class UseAuthProviderStage(PipelineStage):
|
|||
All other operations are passed down.
|
||||
"""
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def _run_op(self, op):
|
||||
def pipeline_ops_done(completed_op):
|
||||
op.error = completed_op.error
|
||||
|
@ -78,6 +84,7 @@ class HandleTwinOperationsStage(PipelineStage):
|
|||
protocol-specific receive event into an IotResponseEvent event.
|
||||
"""
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def _run_op(self, op):
|
||||
def map_twin_error(original_op, twin_op):
|
||||
if twin_op.error:
|
||||
|
|
|
@ -13,6 +13,7 @@ from azure.iot.device.common.pipeline import (
|
|||
pipeline_events_mqtt,
|
||||
PipelineStage,
|
||||
operation_flow,
|
||||
pipeline_thread,
|
||||
)
|
||||
from azure.iot.device.iothub.models import Message, MethodRequest
|
||||
from . import constant, pipeline_ops_iothub, pipeline_events_iothub, mqtt_topic_iothub
|
||||
|
@ -30,6 +31,7 @@ class IoTHubMQTTConverterStage(PipelineStage):
|
|||
super(IoTHubMQTTConverterStage, self).__init__()
|
||||
self.feature_to_topic = {}
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def _run_op(self, op):
|
||||
|
||||
if isinstance(op, pipeline_ops_iothub.SetAuthProviderArgsOperation):
|
||||
|
@ -127,6 +129,7 @@ class IoTHubMQTTConverterStage(PipelineStage):
|
|||
# All other operations get passed down
|
||||
operation_flow.pass_op_to_next_stage(self, op)
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def _set_topic_names(self, device_id, module_id):
|
||||
"""
|
||||
Build topic names based on the device_id and module_id passed.
|
||||
|
@ -144,6 +147,7 @@ class IoTHubMQTTConverterStage(PipelineStage):
|
|||
constant.TWIN_PATCHES: (mqtt_topic_iothub.get_twin_patch_topic_for_subscribe()),
|
||||
}
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def _handle_pipeline_event(self, event):
|
||||
"""
|
||||
Pipeline Event handler function to convert incoming MQTT messages into the appropriate IoTHub
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from azure.iot.device.common.pipeline import pipeline_ops_base, operation_flow
|
||||
from azure.iot.device.common.pipeline import pipeline_ops_base, operation_flow, pipeline_thread
|
||||
from azure.iot.device.common.pipeline.pipeline_stages_base import PipelineStage
|
||||
from . import pipeline_ops_provisioning
|
||||
|
||||
|
@ -28,6 +28,7 @@ class UseSecurityClientStage(PipelineStage):
|
|||
All other operations are passed down.
|
||||
"""
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def _run_op(self, op):
|
||||
def pipeline_ops_done(completed_op):
|
||||
op.error = completed_op.error
|
||||
|
|
|
@ -11,6 +11,7 @@ from azure.iot.device.common.pipeline import (
|
|||
pipeline_ops_mqtt,
|
||||
pipeline_events_mqtt,
|
||||
operation_flow,
|
||||
pipeline_thread,
|
||||
)
|
||||
from azure.iot.device.common.pipeline.pipeline_stages_base import PipelineStage
|
||||
from azure.iot.device.provisioning.pipeline import constant, mqtt_topic
|
||||
|
@ -32,6 +33,7 @@ class ProvisioningMQTTConverterStage(PipelineStage):
|
|||
super(ProvisioningMQTTConverterStage, self).__init__()
|
||||
self.action_to_topic = {}
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def _run_op(self, op):
|
||||
|
||||
if isinstance(
|
||||
|
@ -103,6 +105,7 @@ class ProvisioningMQTTConverterStage(PipelineStage):
|
|||
# All other operations get passed down
|
||||
operation_flow.pass_op_to_next_stage(self, op)
|
||||
|
||||
@pipeline_thread.runs_on_pipeline_thread
|
||||
def _handle_pipeline_event(self, event):
|
||||
"""
|
||||
Pipeline Event handler function to convert incoming MQTT messages into the appropriate DPS
|
||||
|
|
|
@ -63,6 +63,7 @@ setup(
|
|||
"requests>=2.20.0,<3.0.0",
|
||||
"requests-unixsocket>=0.1.5,<1.0.0",
|
||||
"janus>=0.4.0,<1.0.0;python_version>='3.5'",
|
||||
"futures;python_version == '2.7'",
|
||||
],
|
||||
extras_require={":python_version<'3.0'": ["azure-iot-nspkg>=1.0.1"]},
|
||||
python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3*, <4",
|
||||
|
|
|
@ -14,4 +14,7 @@ from tests.common.pipeline.fixtures import (
|
|||
op3,
|
||||
finally_op,
|
||||
new_op,
|
||||
fake_pipeline_thread,
|
||||
fake_non_pipeline_thread,
|
||||
unhandled_error_handler,
|
||||
)
|
||||
|
|
|
@ -4,8 +4,14 @@
|
|||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
import pytest
|
||||
import threading
|
||||
from tests.common.pipeline import helpers
|
||||
from azure.iot.device.common.pipeline import pipeline_events_base, pipeline_ops_base
|
||||
from azure.iot.device.common import unhandled_exceptions
|
||||
from azure.iot.device.common.pipeline import (
|
||||
pipeline_events_base,
|
||||
pipeline_ops_base,
|
||||
pipeline_thread,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -39,35 +45,67 @@ class FakeOperation(pipeline_ops_base.PipelineOperation):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def op():
|
||||
op = FakeOperation()
|
||||
def op(callback):
|
||||
op = FakeOperation(callback=callback)
|
||||
op.name = "op"
|
||||
return op
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def op2():
|
||||
op = FakeOperation()
|
||||
def op2(callback):
|
||||
op = FakeOperation(callback=callback)
|
||||
op.name = "op2"
|
||||
return op
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def op3():
|
||||
op = FakeOperation()
|
||||
def op3(callback):
|
||||
op = FakeOperation(callback=callback)
|
||||
op.name = "op3"
|
||||
return op
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def finally_op():
|
||||
op = FakeOperation()
|
||||
def finally_op(callback):
|
||||
op = FakeOperation(callback=callback)
|
||||
op.name = "finally_op"
|
||||
return op
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def new_op():
|
||||
op = FakeOperation()
|
||||
def new_op(callback):
|
||||
op = FakeOperation(callback=callback)
|
||||
op.name = "new_op"
|
||||
return op
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_pipeline_thread():
|
||||
"""
|
||||
This fixture mocks out the thread name so that the pipeline decorators
|
||||
use to assert that you are in a pipeline thread.
|
||||
"""
|
||||
this_thread = threading.current_thread()
|
||||
old_name = this_thread.name
|
||||
|
||||
this_thread.name = "pipeline"
|
||||
yield
|
||||
this_thread.name = old_name
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_non_pipeline_thread():
|
||||
"""
|
||||
This fixture sets thread name to something other than "pipeline" to force asserts
|
||||
"""
|
||||
this_thread = threading.current_thread()
|
||||
old_name = this_thread.name
|
||||
|
||||
this_thread.name = "not pipeline"
|
||||
yield
|
||||
this_thread.name = old_name
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unhandled_error_handler(mocker):
|
||||
return mocker.patch.object(unhandled_exceptions, "exception_caught_in_background_thread")
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
import inspect
|
||||
import pytest
|
||||
import functools
|
||||
from threading import Event
|
||||
from azure.iot.device.common.pipeline import (
|
||||
pipeline_events_base,
|
||||
pipeline_ops_base,
|
||||
|
@ -95,6 +96,13 @@ def make_mock_stage(mocker, stage_to_make):
|
|||
def assert_callback_succeeded(op, callback=None):
|
||||
if not callback:
|
||||
callback = op.callback
|
||||
try:
|
||||
# if the callback has a __wrapped__ attribute, that means that the
|
||||
# pipeline added a wrapper around the callback, so we want to look
|
||||
# at the original function instead of the wrapped function.
|
||||
callback = callback.__wrapped__
|
||||
except AttributeError:
|
||||
pass
|
||||
assert callback.call_count == 1
|
||||
callback_arg = callback.call_args[0][0]
|
||||
assert callback_arg == op
|
||||
|
@ -104,6 +112,13 @@ def assert_callback_succeeded(op, callback=None):
|
|||
def assert_callback_failed(op, callback=None, error=None):
|
||||
if not callback:
|
||||
callback = op.callback
|
||||
try:
|
||||
# if the callback has a __wrapped__ attribute, that means that the
|
||||
# pipeline added a wrapper around the callback, so we want to look
|
||||
# at the original function instead of the wrapped function.
|
||||
callback = callback.__wrapped__
|
||||
except AttributeError:
|
||||
pass
|
||||
assert callback.call_count == 1
|
||||
callback_arg = callback.call_args[0][0]
|
||||
assert callback_arg == op
|
||||
|
@ -124,11 +139,45 @@ class UnhandledException(BaseException):
|
|||
def get_arg_count(fn):
|
||||
"""
|
||||
return the number of arguments (args) passed into a
|
||||
particular function. Returned value not include kwargs.
|
||||
particular function. Returned value does not include kwargs.
|
||||
"""
|
||||
return len(getargspec(fn).args)
|
||||
try:
|
||||
# if __wrapped__ is set, we're looking at a decorated function
|
||||
# Functools.wraps doesn't copy arg metadata, so we need to
|
||||
# get argument count from the wrapped function instead.
|
||||
return len(getargspec(fn.__wrapped__).args)
|
||||
except AttributeError:
|
||||
return len(getargspec(fn).args)
|
||||
|
||||
|
||||
def make_mock_op_or_event(cls):
|
||||
args = [None for i in (range(get_arg_count(cls.__init__) - 1))]
|
||||
return cls(*args)
|
||||
|
||||
|
||||
def add_mock_method_waiter(obj, method_name):
|
||||
"""
|
||||
For mock methods, add "wait_for_xxx_to_be_called" and "wait_for_xxx_to_not_be_called"
|
||||
helper functions on the object. This is very handy for methods that get called by
|
||||
another thread, when you want your test functions to wait until the other thread is
|
||||
able to call the method without using a sleep call.
|
||||
"""
|
||||
method_called = Event()
|
||||
|
||||
def signal_method_called(*args, **kwargs):
|
||||
method_called.set()
|
||||
|
||||
def wait_for_method_to_be_called():
|
||||
method_called.wait(0.1)
|
||||
assert method_called.isSet()
|
||||
method_called.clear()
|
||||
|
||||
def wait_for_method_to_not_be_called():
|
||||
method_called.wait(0.1)
|
||||
assert not method_called.isSet()
|
||||
|
||||
getattr(obj, method_name).side_effect = signal_method_called
|
||||
setattr(obj, "wait_for_{}_to_be_called".format(method_name), wait_for_method_to_be_called)
|
||||
setattr(
|
||||
obj, "wait_for_{}_to_not_be_called".format(method_name), wait_for_method_to_not_be_called
|
||||
)
|
||||
|
|
|
@ -5,20 +5,35 @@
|
|||
# --------------------------------------------------------------------------
|
||||
import logging
|
||||
import pytest
|
||||
import inspect
|
||||
import threading
|
||||
import concurrent.futures
|
||||
from tests.common.pipeline.helpers import (
|
||||
all_except,
|
||||
make_mock_stage,
|
||||
make_mock_op_or_event,
|
||||
assert_callback_failed,
|
||||
UnhandledException,
|
||||
get_arg_count,
|
||||
add_mock_method_waiter,
|
||||
)
|
||||
from azure.iot.device.common.pipeline.pipeline_stages_base import PipelineStage
|
||||
from tests.common.pipeline.pipeline_data_object_test import add_instantiation_test
|
||||
from azure.iot.device.common.pipeline import pipeline_thread
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def add_base_pipeline_stage_tests(cls, module, all_ops, handled_ops, all_events, handled_events):
|
||||
def add_base_pipeline_stage_tests(
|
||||
cls,
|
||||
module,
|
||||
all_ops,
|
||||
handled_ops,
|
||||
all_events,
|
||||
handled_events,
|
||||
methods_that_enter_pipeline_thread=[],
|
||||
methods_that_can_run_in_any_thread=[],
|
||||
):
|
||||
"""
|
||||
Add all of the "basic" tests for validating a pipeline stage. This includes tests for
|
||||
instantiation and tests for properly handling "unhandled" operations and events".
|
||||
|
@ -33,6 +48,12 @@ def add_base_pipeline_stage_tests(cls, module, all_ops, handled_ops, all_events,
|
|||
add_unknown_events_tests(
|
||||
cls=cls, module=module, all_events=all_events, handled_events=handled_events
|
||||
)
|
||||
add_pipeline_thread_tests(
|
||||
cls=cls,
|
||||
module=module,
|
||||
methods_that_enter_pipeline_thread=methods_that_enter_pipeline_thread,
|
||||
methods_that_can_run_in_any_thread=methods_that_can_run_in_any_thread,
|
||||
)
|
||||
|
||||
|
||||
def add_unknown_ops_tests(cls, module, all_ops, handled_ops):
|
||||
|
@ -44,13 +65,13 @@ def add_unknown_ops_tests(cls, module, all_ops, handled_ops):
|
|||
unknown_ops = all_except(all_items=all_ops, items_to_exclude=handled_ops)
|
||||
|
||||
@pytest.mark.describe("{} - .run_op() -- unknown and unhandled operations".format(cls.__name__))
|
||||
@pytest.mark.parametrize("op_cls", unknown_ops)
|
||||
class LocalTestObject(object):
|
||||
@pytest.fixture
|
||||
def op(self, op_cls, callback):
|
||||
op = make_mock_op_or_event(op_cls)
|
||||
op.callback = callback
|
||||
op.action = "pend"
|
||||
add_mock_method_waiter(op, "callback")
|
||||
return op
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -58,26 +79,32 @@ def add_unknown_ops_tests(cls, module, all_ops, handled_ops):
|
|||
return make_mock_stage(mocker=mocker, stage_to_make=cls)
|
||||
|
||||
@pytest.mark.it("Passes unknown operation to next stage")
|
||||
@pytest.mark.parametrize("op_cls", unknown_ops)
|
||||
def test_passes_op_to_next_stage(self, op_cls, op, stage):
|
||||
stage.run_op(op)
|
||||
assert stage.next.run_op.call_count == 1
|
||||
assert stage.next.run_op.call_args[0][0] == op
|
||||
|
||||
@pytest.mark.it("Fails unknown operation if there is no next stage")
|
||||
@pytest.mark.parametrize("op_cls", unknown_ops)
|
||||
def test_passes_op_with_no_next_stage(self, op_cls, op, stage):
|
||||
stage.next = None
|
||||
stage.run_op(op)
|
||||
op.wait_for_callback_to_be_called()
|
||||
assert_callback_failed(op=op)
|
||||
|
||||
@pytest.mark.it("Catches Exceptions raised when passing unknown operation to next stage")
|
||||
@pytest.mark.parametrize("op_cls", unknown_ops)
|
||||
def test_passes_op_to_next_stage_which_throws_exception(self, op_cls, op, stage):
|
||||
op.action = "exception"
|
||||
stage.run_op(op)
|
||||
op.wait_for_callback_to_be_called()
|
||||
assert_callback_failed(op=op)
|
||||
|
||||
@pytest.mark.it(
|
||||
"Allows BaseExceptions raised when passing unknown operation to next start to propogate"
|
||||
)
|
||||
@pytest.mark.parametrize("op_cls", unknown_ops)
|
||||
def test_passes_op_to_next_stage_which_throws_base_exception(self, op_cls, op, stage):
|
||||
op.action = "base_exception"
|
||||
with pytest.raises(UnhandledException):
|
||||
|
@ -95,6 +122,9 @@ def add_unknown_events_tests(cls, module, all_events, handled_events):
|
|||
|
||||
unknown_events = all_except(all_items=all_events, items_to_exclude=handled_events)
|
||||
|
||||
if not unknown_events:
|
||||
return
|
||||
|
||||
@pytest.mark.describe(
|
||||
"{} - .handle_pipeline_event() -- unknown and unhandled events".format(cls.__name__)
|
||||
)
|
||||
|
@ -122,16 +152,6 @@ def add_unknown_events_tests(cls, module, all_events, handled_events):
|
|||
stage.previous = previous
|
||||
return previous
|
||||
|
||||
@pytest.fixture
|
||||
def unhandled_error_handler(self, stage, mocker):
|
||||
class MockPipelineRootStage(object):
|
||||
def __init__(self):
|
||||
self.unhandled_error_handler = mocker.MagicMock()
|
||||
|
||||
root = MockPipelineRootStage()
|
||||
stage.pipeline_root = root
|
||||
return root.unhandled_error_handler
|
||||
|
||||
@pytest.mark.it("Passes unknown event to previous stage")
|
||||
def test_passes_event_to_previous_stage(self, event_cls, stage, event, previous):
|
||||
stage.handle_pipeline_event(event)
|
||||
|
@ -167,6 +187,81 @@ def add_unknown_events_tests(cls, module, all_events, handled_events):
|
|||
stage.handle_pipeline_event(event)
|
||||
assert unhandled_error_handler.call_count == 0
|
||||
|
||||
pass
|
||||
|
||||
setattr(module, "Test{}UnknownEvents".format(cls.__name__), LocalTestObject)
|
||||
|
||||
|
||||
class ThreadLaunchedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def add_pipeline_thread_tests(
|
||||
cls, module, methods_that_enter_pipeline_thread, methods_that_can_run_in_any_thread
|
||||
):
|
||||
def does_method_assert_pipeline_thread(method_name):
|
||||
if method_name.startswith("__"):
|
||||
return False
|
||||
elif method_name in methods_that_enter_pipeline_thread:
|
||||
return False
|
||||
elif method_name in methods_that_can_run_in_any_thread:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
methods_that_assert_pipeline_thread = [
|
||||
x[0]
|
||||
for x in inspect.getmembers(cls, inspect.isfunction)
|
||||
if does_method_assert_pipeline_thread(x[0])
|
||||
]
|
||||
|
||||
@pytest.mark.describe("{} - Pipeline threading".format(cls.__name__))
|
||||
class LocalTestObject(object):
|
||||
@pytest.fixture
|
||||
def stage(self):
|
||||
return cls()
|
||||
|
||||
@pytest.mark.parametrize("method_name", methods_that_assert_pipeline_thread)
|
||||
@pytest.mark.it("Enforces use of the pipeline thread when calling method")
|
||||
def test_asserts_in_pipeline(self, stage, method_name, fake_non_pipeline_thread):
|
||||
func = getattr(stage, method_name)
|
||||
args = [None for i in (range(get_arg_count(func) - 1))]
|
||||
with pytest.raises(AssertionError):
|
||||
func(*args)
|
||||
|
||||
if methods_that_enter_pipeline_thread:
|
||||
|
||||
@pytest.mark.parametrize("method_name", methods_that_enter_pipeline_thread)
|
||||
@pytest.mark.it("Automatically enters the pipeline thread when calling method")
|
||||
def test_enters_pipeline(self, mocker, stage, method_name, fake_non_pipeline_thread):
|
||||
func = getattr(stage, method_name)
|
||||
args = [None for i in (range(get_arg_count(func) - 1))]
|
||||
|
||||
#
|
||||
# We take a bit of a roundabout way to verify that the functuion enters the
|
||||
# pipeline executor:
|
||||
#
|
||||
# 1. we verify that the method got the pipeline executor
|
||||
# 2. we verify that the method invoked _something_ on the pipeline executor
|
||||
#
|
||||
# It's not perfect, but it's good enough.
|
||||
#
|
||||
# We do this because:
|
||||
# 1. We don't have the exact right args to run the method and we don't want
|
||||
# to add the complexity to get the right args in this test.
|
||||
# 2. We can't replace the wrapped method with a mock, AFAIK.
|
||||
#
|
||||
pipeline_executor = pipeline_thread._get_named_executor("pipeline")
|
||||
mocker.patch.object(pipeline_executor, "submit")
|
||||
pipeline_executor.submit.side_effect = ThreadLaunchedError
|
||||
mocker.spy(pipeline_thread, "_get_named_executor")
|
||||
|
||||
# If the method calls submit on some executor, it will raise a ThreadLaunchedError
|
||||
with pytest.raises(ThreadLaunchedError):
|
||||
func(*args)
|
||||
|
||||
# now verify that the code got the pipeline executor and verify that it used that
|
||||
# executor to launch something.
|
||||
assert pipeline_thread._get_named_executor.call_count == 1
|
||||
assert pipeline_thread._get_named_executor.call_args[0][0] == "pipeline"
|
||||
assert pipeline_executor.submit.call_count == 1
|
||||
|
||||
setattr(module, "Test{}PipelineThreading".format(cls.__name__), LocalTestObject)
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
import logging
|
||||
import pytest
|
||||
from azure.iot.device.common.pipeline import (
|
||||
pipeline_thread,
|
||||
pipeline_stages_base,
|
||||
pipeline_ops_base,
|
||||
pipeline_events_base,
|
||||
|
@ -26,6 +27,16 @@ from tests.common.pipeline.helpers import (
|
|||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
# This fixture makes it look like all test in this file tests are running
|
||||
# inside the pipeline thread. Because this is an autouse fixture, we
|
||||
# manually add it to the individual test.py files that need it. If,
|
||||
# instead, we had added it to some conftest.py, it would be applied to
|
||||
# every tests in every file and we don't want that.
|
||||
@pytest.fixture(autouse=True)
|
||||
def apply_fake_pipeline_thread(fake_pipeline_thread):
|
||||
pass
|
||||
|
||||
|
||||
class MockPipelineStage(pipeline_stages_base.PipelineStage):
|
||||
def _run_op(self, op):
|
||||
pass_op_to_next_stage(self, op)
|
||||
|
@ -79,13 +90,15 @@ class TestRunOpsSerialOneOpButNoFinallyOp(object):
|
|||
@pytest.mark.it(
|
||||
"Handles Exceptions raised in the callback and passes them to the unhandled error handler"
|
||||
)
|
||||
def test_callback_throws_exception(self, stage, mocker, fake_exception, op):
|
||||
def test_callback_throws_exception(
|
||||
self, stage, mocker, fake_exception, op, unhandled_error_handler
|
||||
):
|
||||
callback = mocker.Mock(side_effect=fake_exception)
|
||||
run_ops_in_serial(stage, op, callback=callback)
|
||||
assert callback.call_count == 1
|
||||
assert callback.call_args == mocker.call(op)
|
||||
assert stage.unhandled_error_handler.call_count == 1
|
||||
assert stage.unhandled_error_handler.call_args == mocker.call(fake_exception)
|
||||
assert unhandled_error_handler.call_count == 1
|
||||
assert unhandled_error_handler.call_args == mocker.call(fake_exception)
|
||||
|
||||
@pytest.mark.it("Allows any BaseExceptions raised in the callback to propagate")
|
||||
def test_callback_throws_base_exception(self, stage, mocker, fake_base_exception, op):
|
||||
|
@ -153,13 +166,15 @@ class TestRunOpsSerialOneOpAndFinallyOp(object):
|
|||
@pytest.mark.it(
|
||||
"Handles Exceptions raised in the callback and passes them to the unhandled error handler"
|
||||
)
|
||||
def test_callback_raises_exception(self, stage, op, finally_op, fake_exception, mocker):
|
||||
def test_callback_raises_exception(
|
||||
self, stage, op, finally_op, fake_exception, mocker, unhandled_error_handler
|
||||
):
|
||||
callback = mocker.Mock(side_effect=fake_exception)
|
||||
run_ops_in_serial(stage, op, finally_op=finally_op, callback=callback)
|
||||
assert callback.call_count == 1
|
||||
assert callback.call_args == mocker.call(finally_op)
|
||||
assert stage.unhandled_error_handler.call_count == 1
|
||||
assert stage.unhandled_error_handler.call_args == mocker.call(fake_exception)
|
||||
assert unhandled_error_handler.call_count == 1
|
||||
assert unhandled_error_handler.call_args == mocker.call(fake_exception)
|
||||
|
||||
@pytest.mark.it("Allows any BaseExceptions raised in the callback to propagate")
|
||||
def test_callback_raises_base_exception(
|
||||
|
@ -197,13 +212,15 @@ class TestRunOpsSerialThreeOpsButNoFinallyOp(object):
|
|||
@pytest.mark.it(
|
||||
"Handles Exceptions raised in the callback and passes them to the unhandled error handler"
|
||||
)
|
||||
def test_callback_raises_exception(self, stage, op, op2, op3, fake_exception, mocker):
|
||||
def test_callback_raises_exception(
|
||||
self, stage, op, op2, op3, fake_exception, mocker, unhandled_error_handler
|
||||
):
|
||||
callback = mocker.Mock(side_effect=fake_exception)
|
||||
run_ops_in_serial(stage, op, op2, op3, callback=callback)
|
||||
assert callback.call_count == 1
|
||||
assert callback.call_args == mocker.call(op3)
|
||||
assert stage.unhandled_error_handler.call_count == 1
|
||||
assert stage.unhandled_error_handler.call_args == mocker.call(fake_exception)
|
||||
assert unhandled_error_handler.call_count == 1
|
||||
assert unhandled_error_handler.call_args == mocker.call(fake_exception)
|
||||
|
||||
@pytest.mark.it("Allows any BaseExceptions raised in the callback to propagate")
|
||||
def test_callback_raises_base_exception(self, stage, op, op2, op3, fake_base_exception, mocker):
|
||||
|
@ -305,12 +322,12 @@ class TestRunOpsSerialThreeOpsAndFinallyOp(object):
|
|||
"Handles Exceptions raised in the callback and passes them to the unhandled error handler"
|
||||
)
|
||||
def test_callback_raises_exception(
|
||||
self, stage, op, op2, op3, finally_op, fake_exception, mocker
|
||||
self, stage, op, op2, op3, finally_op, fake_exception, mocker, unhandled_error_handler
|
||||
):
|
||||
callback = mocker.Mock(side_effect=fake_exception)
|
||||
run_ops_in_serial(stage, op, op2, op3, callback=callback, finally_op=finally_op)
|
||||
assert stage.unhandled_error_handler.call_count == 1
|
||||
assert stage.unhandled_error_handler.call_args == mocker.call(fake_exception)
|
||||
assert unhandled_error_handler.call_count == 1
|
||||
assert unhandled_error_handler.call_args == mocker.call(fake_exception)
|
||||
|
||||
@pytest.mark.it("Allows any BaseExceptions raised in the callback to propagate")
|
||||
def test_callback_raises_base_exception(
|
||||
|
@ -505,13 +522,15 @@ class TestCompleteOp(object):
|
|||
@pytest.mark.it(
|
||||
"Handles Exceptions raised in operation callback and passes them to the unhandled error handler"
|
||||
)
|
||||
def test_op_callback_raises_exception(self, stage, op, fake_exception, mocker):
|
||||
def test_op_callback_raises_exception(
|
||||
self, stage, op, fake_exception, mocker, unhandled_error_handler
|
||||
):
|
||||
op.callback = mocker.Mock(side_effect=fake_exception)
|
||||
complete_op(stage, op)
|
||||
assert op.callback.call_count == 1
|
||||
assert op.callback.call_args == mocker.call(op)
|
||||
assert stage.unhandled_error_handler.call_count == 1
|
||||
assert stage.unhandled_error_handler.call_args == mocker.call(fake_exception)
|
||||
assert unhandled_error_handler.call_count == 1
|
||||
assert unhandled_error_handler.call_args == mocker.call(fake_exception)
|
||||
|
||||
@pytest.mark.it("Allows any BaseExceptions raised in operation callback to propagate")
|
||||
def test_op_callback_raises_base_exception(self, stage, op, fake_base_exception, mocker):
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
import logging
|
||||
import pytest
|
||||
import sys
|
||||
import threading
|
||||
from azure.iot.device.common.pipeline import (
|
||||
pipeline_stages_base,
|
||||
pipeline_ops_base,
|
||||
|
@ -26,6 +27,17 @@ logging.basicConfig(level=logging.INFO)
|
|||
|
||||
this_module = sys.modules[__name__]
|
||||
|
||||
|
||||
# This fixture makes it look like all test in this file tests are running
|
||||
# inside the pipeline thread. Because this is an autouse fixture, we
|
||||
# manually add it to the individual test.py files that need it. If,
|
||||
# instead, we had added it to some conftest.py, it would be applied to
|
||||
# every tests in every file and we don't want that.
|
||||
@pytest.fixture(autouse=True)
|
||||
def apply_fake_pipeline_thread(fake_pipeline_thread):
|
||||
pass
|
||||
|
||||
|
||||
pipeline_stage_test.add_base_pipeline_stage_tests(
|
||||
cls=pipeline_stages_base.EnsureConnectionStage,
|
||||
module=this_module,
|
||||
|
@ -42,6 +54,68 @@ pipeline_stage_test.add_base_pipeline_stage_tests(
|
|||
handled_events=[],
|
||||
)
|
||||
|
||||
|
||||
# Workaround for flake8. A class with this name is actually created inside
|
||||
# add_base_pipeline_stage_test, but flake8 doesn't know that
|
||||
class TestPipelineRootStagePipelineThreading:
|
||||
pass
|
||||
|
||||
|
||||
pipeline_stage_test.add_base_pipeline_stage_tests(
|
||||
cls=pipeline_stages_base.PipelineRootStage,
|
||||
module=this_module,
|
||||
all_ops=all_common_ops,
|
||||
handled_ops=[],
|
||||
all_events=all_common_events,
|
||||
handled_events=all_common_events,
|
||||
methods_that_can_run_in_any_thread=["append_stage", "run_op"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.it("Calls operation callback in callback thread")
|
||||
def _test_pipeline_root_runs_callback_in_callback_thread(self, stage, mocker):
|
||||
# the stage fixture comes from the TestPipelineRootStagePipelineThreading object that
|
||||
# this test method gets added to, so it's a PipelineRootStage object
|
||||
stage.pipeline_root = stage
|
||||
callback_called = threading.Event()
|
||||
|
||||
def callback(op):
|
||||
assert threading.current_thread().name == "callback"
|
||||
callback_called.set()
|
||||
|
||||
op = pipeline_ops_base.ConnectOperation(callback=callback)
|
||||
stage.run_op(op)
|
||||
callback_called.wait()
|
||||
|
||||
|
||||
@pytest.mark.it("Runs operation in pipeline thread")
|
||||
def _test_pipeline_root_runs_operation_in_pipeline_thread(
|
||||
self, mocker, stage, op, fake_non_pipeline_thread
|
||||
):
|
||||
# the stage fixture comes from the TestPipelineRootStagePipelineThreading object that
|
||||
# this test method gets added to, so it's a PipelineRootStage object
|
||||
assert threading.current_thread().name != "pipeline"
|
||||
|
||||
def mock_run_op(self, op):
|
||||
print("mock_run_op called")
|
||||
assert threading.current_thread().name == "pipeline"
|
||||
op.callback(op)
|
||||
|
||||
mock_run_op = mocker.MagicMock(mock_run_op)
|
||||
stage._run_op = mock_run_op
|
||||
|
||||
stage.run_op(op)
|
||||
assert mock_run_op.call_count == 1
|
||||
|
||||
|
||||
TestPipelineRootStagePipelineThreading.test_runs_callback_in_callback_thread = (
|
||||
_test_pipeline_root_runs_callback_in_callback_thread
|
||||
)
|
||||
TestPipelineRootStagePipelineThreading.test_runs_operation_in_pipeline_thread = (
|
||||
_test_pipeline_root_runs_operation_in_pipeline_thread
|
||||
)
|
||||
|
||||
|
||||
pipeline_stage_test.add_base_pipeline_stage_tests(
|
||||
cls=pipeline_stages_base.CoordinateRequestAndResponseStage,
|
||||
module=this_module,
|
||||
|
@ -171,27 +245,29 @@ class TestCoordinateRequestAndResponseSendIotRequestHandleEvent(object):
|
|||
@pytest.mark.it(
|
||||
"Calls the unhandled error handler if there is no previous stage when request_id matches"
|
||||
)
|
||||
def test_matching_request_id_with_no_previous_stage(self, stage, op, iot_response):
|
||||
def test_matching_request_id_with_no_previous_stage(
|
||||
self, stage, op, iot_response, unhandled_error_handler
|
||||
):
|
||||
stage.next.previous = None
|
||||
operation_flow.pass_event_to_previous_stage(stage.next, iot_response)
|
||||
assert stage.pipeline_root.unhandled_error_handler.call_count == 1
|
||||
assert unhandled_error_handler.call_count == 1
|
||||
|
||||
@pytest.mark.it(
|
||||
"Does nothing if an IotResponse with an identical request_id is received a second time"
|
||||
)
|
||||
def test_ignores_duplicate_request_id(self, stage, op, iot_response):
|
||||
def test_ignores_duplicate_request_id(self, stage, op, iot_response, unhandled_error_handler):
|
||||
operation_flow.pass_event_to_previous_stage(stage.next, iot_response)
|
||||
assert_callback_succeeded(op=op)
|
||||
op.callback.reset_mock()
|
||||
|
||||
operation_flow.pass_event_to_previous_stage(stage.next, iot_response)
|
||||
assert op.callback.call_count == 0
|
||||
assert stage.pipeline_root.unhandled_error_handler.call_count == 0
|
||||
assert unhandled_error_handler.call_count == 0
|
||||
|
||||
@pytest.mark.it(
|
||||
"Does nothing if an IotResponse with a request_id is received for an operation that returned failure"
|
||||
)
|
||||
def test_ignores_request_id_from_failure(self, stage, op, mocker):
|
||||
def test_ignores_request_id_from_failure(self, stage, op, mocker, unhandled_error_handler):
|
||||
stage.next._run_op = mocker.MagicMock(side_effect=Exception)
|
||||
stage.run_op(op)
|
||||
|
||||
|
@ -205,11 +281,11 @@ class TestCoordinateRequestAndResponseSendIotRequestHandleEvent(object):
|
|||
op.callback.reset_mock()
|
||||
operation_flow.pass_event_to_previous_stage(stage.next, resp)
|
||||
assert op.callback.call_count == 0
|
||||
assert stage.pipeline_root.unhandled_error_handler.call_count == 0
|
||||
assert unhandled_error_handler.call_count == 0
|
||||
|
||||
@pytest.mark.it("Does nothing if an IotResponse with an unknown request_id is received")
|
||||
def test_ignores_unknown_request_id(self, stage, op, iot_response):
|
||||
def test_ignores_unknown_request_id(self, stage, op, iot_response, unhandled_error_handler):
|
||||
iot_response.request_id = fake_request_id
|
||||
operation_flow.pass_event_to_previous_stage(stage.next, iot_response)
|
||||
assert op.callback.call_count == 0
|
||||
assert stage.pipeline_root.unhandled_error_handler.call_count == 0
|
||||
assert unhandled_error_handler.call_count == 0
|
||||
|
|
|
@ -25,6 +25,17 @@ from tests.common.pipeline import pipeline_stage_test
|
|||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
# This fixture makes it look like all test in this file tests are running
|
||||
# inside the pipeline thread. Because this is an autouse fixture, we
|
||||
# manually add it to the individual test.py files that need it. If,
|
||||
# instead, we had added it to some conftest.py, it would be applied to
|
||||
# every tests in every file and we don't want that.
|
||||
@pytest.fixture(autouse=True)
|
||||
def apply_fake_pipeline_thread(fake_pipeline_thread):
|
||||
pass
|
||||
|
||||
|
||||
this_module = sys.modules[__name__]
|
||||
|
||||
fake_client_id = "__fake_client_id__"
|
||||
|
@ -57,6 +68,7 @@ pipeline_stage_test.add_base_pipeline_stage_tests(
|
|||
handled_ops=ops_handled_by_this_stage,
|
||||
all_events=all_common_events,
|
||||
handled_events=events_handled_by_this_stage,
|
||||
methods_that_enter_pipeline_thread=["_on_message_received", "on_connected", "on_disconnected"],
|
||||
)
|
||||
|
||||
|
||||
|
@ -150,7 +162,7 @@ def op_set_sas_token(callback):
|
|||
@pytest.fixture
|
||||
def op_set_client_certificate(callback):
|
||||
return pipeline_ops_base.SetClientAuthenticationCertificateOperation(
|
||||
certificate=fake_certificate
|
||||
certificate=fake_certificate, callback=callback
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -4,4 +4,12 @@
|
|||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from tests.common.pipeline.fixtures import callback, fake_exception, fake_base_exception, event
|
||||
from tests.common.pipeline.fixtures import (
|
||||
callback,
|
||||
fake_exception,
|
||||
fake_base_exception,
|
||||
event,
|
||||
fake_pipeline_thread,
|
||||
fake_non_pipeline_thread,
|
||||
unhandled_error_handler,
|
||||
)
|
||||
|
|
|
@ -29,6 +29,16 @@ logging.basicConfig(level=logging.INFO)
|
|||
this_module = sys.modules[__name__]
|
||||
|
||||
|
||||
# This fixture makes it look like all test in this file tests are running
|
||||
# inside the pipeline thread. Because this is an autouse fixture, we
|
||||
# manually add it to the individual test.py files that need it. If,
|
||||
# instead, we had added it to some conftest.py, it would be applied to
|
||||
# every tests in every file and we don't want that.
|
||||
@pytest.fixture(autouse=True)
|
||||
def apply_fake_pipeline_thread(fake_pipeline_thread):
|
||||
pass
|
||||
|
||||
|
||||
fake_device_id = "__fake_device_id__"
|
||||
fake_module_id = "__fake_module_id__"
|
||||
fake_hostname = "__fake_hostname__"
|
||||
|
|
|
@ -39,6 +39,17 @@ logging.basicConfig(level=logging.INFO)
|
|||
|
||||
this_module = sys.modules[__name__]
|
||||
|
||||
|
||||
# This fixture makes it look like all test in this file tests are running
|
||||
# inside the pipeline thread. Because this is an autouse fixture, we
|
||||
# manually add it to the individual test.py files that need it. If,
|
||||
# instead, we had added it to some conftest.py, it would be applied to
|
||||
# every tests in every file and we don't want that.
|
||||
@pytest.fixture(autouse=True)
|
||||
def apply_fake_pipeline_thread(fake_pipeline_thread):
|
||||
pass
|
||||
|
||||
|
||||
fake_device_id = "__fake_device_id__"
|
||||
fake_module_id = "__fake_module_id__"
|
||||
fake_hostname = "__fake_hostname__"
|
||||
|
@ -509,7 +520,6 @@ class TestIoTHubMQTTConverterWithEnableFeature(object):
|
|||
def add_pipeline_root(stage, mocker):
|
||||
root = pipeline_stages_base.PipelineRootStage()
|
||||
mocker.spy(root, "handle_pipeline_event")
|
||||
mocker.spy(root, "unhandled_error_handler")
|
||||
stage.previous = root
|
||||
stage.pipeline_root = root
|
||||
|
||||
|
@ -895,46 +905,61 @@ class TestIotHubMQTTConverterHandlePipelineEventTwinResponse(object):
|
|||
assert new_event.response_body == fake_payload
|
||||
|
||||
@pytest.mark.it("Calls the unhandled exception handler if there is no previous stage")
|
||||
def test_no_previous_stage(self, stage, fixup_stage_for_test, fake_event):
|
||||
def test_no_previous_stage(
|
||||
self, stage, fixup_stage_for_test, fake_event, unhandled_error_handler
|
||||
):
|
||||
stage.previous = None
|
||||
stage.handle_pipeline_event(fake_event)
|
||||
assert stage.pipeline_root.unhandled_error_handler.call_count == 1
|
||||
assert isinstance(
|
||||
stage.pipeline_root.unhandled_error_handler.call_args[0][0], NotImplementedError
|
||||
)
|
||||
assert unhandled_error_handler.call_count == 1
|
||||
assert isinstance(unhandled_error_handler.call_args[0][0], NotImplementedError)
|
||||
|
||||
@pytest.mark.it(
|
||||
"Calls the unhandled exception handler if the requet_id is missing from the topic name"
|
||||
)
|
||||
def test_invalid_topic_with_missing_request_id(
|
||||
self, stage, fixup_stage_for_test, fake_event, fake_topic_name_with_missing_request_id
|
||||
self,
|
||||
stage,
|
||||
fixup_stage_for_test,
|
||||
fake_event,
|
||||
fake_topic_name_with_missing_request_id,
|
||||
unhandled_error_handler,
|
||||
):
|
||||
fake_event.topic = fake_topic_name_with_missing_request_id
|
||||
stage.handle_pipeline_event(event=fake_event)
|
||||
assert stage.pipeline_root.unhandled_error_handler.call_count == 1
|
||||
assert isinstance(stage.pipeline_root.unhandled_error_handler.call_args[0][0], IndexError)
|
||||
assert unhandled_error_handler.call_count == 1
|
||||
assert isinstance(unhandled_error_handler.call_args[0][0], IndexError)
|
||||
|
||||
@pytest.mark.it(
|
||||
"Calls the unhandled exception handler if the status code is missing from the topic name"
|
||||
)
|
||||
def test_invlid_topic_with_missing_status_code(
|
||||
self, stage, fixup_stage_for_test, fake_event, fake_topic_name_with_missing_status_code
|
||||
self,
|
||||
stage,
|
||||
fixup_stage_for_test,
|
||||
fake_event,
|
||||
fake_topic_name_with_missing_status_code,
|
||||
unhandled_error_handler,
|
||||
):
|
||||
fake_event.topic = fake_topic_name_with_missing_status_code
|
||||
stage.handle_pipeline_event(event=fake_event)
|
||||
assert stage.pipeline_root.unhandled_error_handler.call_count == 1
|
||||
assert isinstance(stage.pipeline_root.unhandled_error_handler.call_args[0][0], ValueError)
|
||||
assert unhandled_error_handler.call_count == 1
|
||||
assert isinstance(unhandled_error_handler.call_args[0][0], ValueError)
|
||||
|
||||
@pytest.mark.it(
|
||||
"Calls the unhandled exception handler if the status code in the topic name is not numeric"
|
||||
)
|
||||
def test_invlid_topic_with_bad_status_code(
|
||||
self, stage, fixup_stage_for_test, fake_event, fake_topic_name_with_bad_status_code
|
||||
self,
|
||||
stage,
|
||||
fixup_stage_for_test,
|
||||
fake_event,
|
||||
fake_topic_name_with_bad_status_code,
|
||||
unhandled_error_handler,
|
||||
):
|
||||
fake_event.topic = fake_topic_name_with_bad_status_code
|
||||
stage.handle_pipeline_event(event=fake_event)
|
||||
assert stage.pipeline_root.unhandled_error_handler.call_count == 1
|
||||
assert isinstance(stage.pipeline_root.unhandled_error_handler.call_args[0][0], ValueError)
|
||||
assert unhandled_error_handler.call_count == 1
|
||||
assert isinstance(unhandled_error_handler.call_args[0][0], ValueError)
|
||||
|
||||
|
||||
@pytest.mark.describe(
|
||||
|
@ -984,30 +1009,34 @@ class TestIotHubMQTTConverterHandlePipelineEventTwinPatch(object):
|
|||
assert new_event.patch == fake_patch
|
||||
|
||||
@pytest.mark.it("Calls the unhandled exception handler if there is no previous stage")
|
||||
def test_no_previous_stage(self, stage, fixup_stage_for_test, fake_event):
|
||||
def test_no_previous_stage(
|
||||
self, stage, fixup_stage_for_test, fake_event, unhandled_error_handler
|
||||
):
|
||||
stage.previous = None
|
||||
stage.handle_pipeline_event(fake_event)
|
||||
assert stage.pipeline_root.unhandled_error_handler.call_count == 1
|
||||
assert isinstance(
|
||||
stage.pipeline_root.unhandled_error_handler.call_args[0][0], NotImplementedError
|
||||
)
|
||||
assert unhandled_error_handler.call_count == 1
|
||||
assert isinstance(unhandled_error_handler.call_args[0][0], NotImplementedError)
|
||||
|
||||
@pytest.mark.it("Calls the unhandled exception handler if the payload is not a Bytes object")
|
||||
def test_payload_not_bytes(self, stage, fixup_stage_for_test, fake_event, fake_patch_not_bytes):
|
||||
def test_payload_not_bytes(
|
||||
self, stage, fixup_stage_for_test, fake_event, fake_patch_not_bytes, unhandled_error_handler
|
||||
):
|
||||
fake_event.payload = fake_patch_not_bytes
|
||||
stage.handle_pipeline_event(fake_event)
|
||||
assert stage.pipeline_root.unhandled_error_handler.call_count == 1
|
||||
assert unhandled_error_handler.call_count == 1
|
||||
if not (
|
||||
isinstance(stage.pipeline_root.unhandled_error_handler.call_args[0][0], AttributeError)
|
||||
or isinstance(stage.pipeline_root.unhandled_error_handler.call_args[0][0], ValueError)
|
||||
isinstance(unhandled_error_handler.call_args[0][0], AttributeError)
|
||||
or isinstance(unhandled_error_handler.call_args[0][0], ValueError)
|
||||
):
|
||||
assert False
|
||||
|
||||
@pytest.mark.it(
|
||||
"Calls the unhandled exception handler if the payload cannot be deserialized as a JSON object"
|
||||
)
|
||||
def test_payload_not_json(self, stage, fixup_stage_for_test, fake_event, fake_patch_not_json):
|
||||
def test_payload_not_json(
|
||||
self, stage, fixup_stage_for_test, fake_event, fake_patch_not_json, unhandled_error_handler
|
||||
):
|
||||
fake_event.payload = fake_patch_not_json
|
||||
stage.handle_pipeline_event(fake_event)
|
||||
assert stage.pipeline_root.unhandled_error_handler.call_count == 1
|
||||
assert isinstance(stage.pipeline_root.unhandled_error_handler.call_args[0][0], ValueError)
|
||||
assert unhandled_error_handler.call_count == 1
|
||||
assert isinstance(unhandled_error_handler.call_args[0][0], ValueError)
|
||||
|
|
|
@ -4,4 +4,12 @@
|
|||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from tests.common.pipeline.fixtures import callback, fake_exception, fake_base_exception, event
|
||||
from tests.common.pipeline.fixtures import (
|
||||
callback,
|
||||
fake_exception,
|
||||
fake_base_exception,
|
||||
event,
|
||||
fake_pipeline_thread,
|
||||
fake_non_pipeline_thread,
|
||||
unhandled_error_handler,
|
||||
)
|
||||
|
|
|
@ -33,6 +33,13 @@ logging.basicConfig(level=logging.INFO)
|
|||
|
||||
this_module = sys.modules[__name__]
|
||||
|
||||
|
||||
# Make it look like we're always running inside pipeline threads
|
||||
@pytest.fixture(autouse=True)
|
||||
def apply_fake_pipeline_thread(fake_pipeline_thread):
|
||||
pass
|
||||
|
||||
|
||||
fake_device_id = "elder_wand"
|
||||
fake_registration_id = "registered_remembrall"
|
||||
fake_provisioning_host = "hogwarts.com"
|
||||
|
|
|
@ -34,6 +34,17 @@ logging.basicConfig(level=logging.INFO)
|
|||
|
||||
this_module = sys.modules[__name__]
|
||||
|
||||
|
||||
# This fixture makes it look like all test in this file tests are running
|
||||
# inside the pipeline thread. Because this is an autouse fixture, we
|
||||
# manually add it to the individual test.py files that need it. If,
|
||||
# instead, we had added it to some conftest.py, it would be applied to
|
||||
# every tests in every file and we don't want that.
|
||||
@pytest.fixture(autouse=True)
|
||||
def apply_fake_pipeline_thread(fake_pipeline_thread):
|
||||
pass
|
||||
|
||||
|
||||
fake_device_id = "elder_wand"
|
||||
fake_registration_id = "registered_remembrall"
|
||||
fake_provisioning_host = "hogwarts.com"
|
||||
|
|
|
@ -10,6 +10,7 @@ from azure.iot.device.common.models import X509
|
|||
from azure.iot.device.provisioning.security.sk_security_client import SymmetricKeySecurityClient
|
||||
from azure.iot.device.provisioning.security.x509_security_client import X509SecurityClient
|
||||
from azure.iot.device.provisioning.pipeline.provisioning_pipeline import ProvisioningPipeline
|
||||
from tests.common.pipeline import helpers
|
||||
|
||||
send_msg_qos = 1
|
||||
|
||||
|
@ -110,6 +111,10 @@ def mock_provisioning_pipeline(params_security_clients):
|
|||
provisioning_pipeline.on_connected = MagicMock()
|
||||
provisioning_pipeline.on_disconnected = MagicMock()
|
||||
provisioning_pipeline.on_message_received = MagicMock()
|
||||
helpers.add_mock_method_waiter(provisioning_pipeline, "on_connected")
|
||||
helpers.add_mock_method_waiter(provisioning_pipeline, "on_disconnected")
|
||||
helpers.add_mock_method_waiter(provisioning_pipeline._pipeline.transport, "publish")
|
||||
|
||||
yield provisioning_pipeline
|
||||
provisioning_pipeline.disconnect()
|
||||
|
||||
|
@ -148,6 +153,7 @@ class TestConnect(object):
|
|||
assert_for_client_x509(mock_mqtt_transport.connect.call_args[1]["client_certificate"])
|
||||
|
||||
mock_mqtt_transport.on_mqtt_connected()
|
||||
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
|
||||
|
||||
@pytest.mark.it("After complete calls handler with new state")
|
||||
def test_connected_state_handler_called_wth_new_state_once_provider_gets_connected(
|
||||
|
@ -157,6 +163,7 @@ class TestConnect(object):
|
|||
|
||||
mock_provisioning_pipeline.connect()
|
||||
mock_mqtt_transport.on_mqtt_connected()
|
||||
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
|
||||
|
||||
mock_provisioning_pipeline.on_connected.assert_called_once_with("connected")
|
||||
|
||||
|
@ -169,6 +176,7 @@ class TestConnect(object):
|
|||
mock_provisioning_pipeline.connect()
|
||||
mock_provisioning_pipeline.connect()
|
||||
mock_mqtt_transport.on_mqtt_connected()
|
||||
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
|
||||
|
||||
assert mock_mqtt_transport.connect.call_count == 1
|
||||
|
||||
|
@ -187,6 +195,7 @@ class TestConnect(object):
|
|||
|
||||
mock_provisioning_pipeline.connect()
|
||||
mock_mqtt_transport.on_mqtt_connected()
|
||||
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
|
||||
|
||||
mock_mqtt_transport.reset_mock()
|
||||
mock_provisioning_pipeline.on_connected.reset_mock()
|
||||
|
@ -197,11 +206,13 @@ class TestConnect(object):
|
|||
mock_provisioning_pipeline.connect()
|
||||
|
||||
mock_mqtt_transport.connect.assert_not_called()
|
||||
mock_provisioning_pipeline.wait_for_on_connected_to_not_be_called()
|
||||
mock_provisioning_pipeline.on_connected.assert_not_called()
|
||||
|
||||
mock_mqtt_transport.on_mqtt_published(0)
|
||||
|
||||
mock_mqtt_transport.connect.assert_not_called()
|
||||
mock_provisioning_pipeline.wait_for_on_connected_to_not_be_called()
|
||||
mock_provisioning_pipeline.on_connected.assert_not_called()
|
||||
|
||||
|
||||
|
@ -216,6 +227,7 @@ class TestSendRegister(object):
|
|||
|
||||
mock_provisioning_pipeline.connect()
|
||||
mock_mqtt_transport.on_mqtt_connected()
|
||||
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
|
||||
mock_provisioning_pipeline.send_request(
|
||||
request_id=fake_request_id, request_payload=fake_mqtt_payload
|
||||
)
|
||||
|
@ -233,6 +245,7 @@ class TestSendRegister(object):
|
|||
fake_request_id
|
||||
)
|
||||
|
||||
mock_mqtt_transport.wait_for_publish_to_be_called()
|
||||
assert mock_mqtt_transport.publish.call_count == 1
|
||||
assert mock_mqtt_transport.publish.call_args[1]["topic"] == fake_publish_topic
|
||||
assert mock_mqtt_transport.publish.call_args[1]["payload"] == fake_mqtt_payload
|
||||
|
@ -259,11 +272,14 @@ class TestSendRegister(object):
|
|||
assert_for_client_x509(mock_mqtt_transport.connect.call_args[1]["client_certificate"])
|
||||
|
||||
# verify that we're not connected yet and verify that we havent't published yet
|
||||
mock_provisioning_pipeline.wait_for_on_connected_to_not_be_called()
|
||||
mock_provisioning_pipeline.on_connected.assert_not_called()
|
||||
mock_mqtt_transport.wait_for_publish_to_not_be_called()
|
||||
mock_mqtt_transport.publish.assert_not_called()
|
||||
|
||||
# finish the connection
|
||||
mock_mqtt_transport.on_mqtt_connected()
|
||||
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
|
||||
|
||||
# verify that our connected callback was called and verify that we published the event
|
||||
mock_provisioning_pipeline.on_connected.assert_called_once_with("connected")
|
||||
|
@ -272,6 +288,7 @@ class TestSendRegister(object):
|
|||
fake_request_id
|
||||
)
|
||||
|
||||
mock_mqtt_transport.wait_for_publish_to_be_called()
|
||||
assert mock_mqtt_transport.publish.call_count == 1
|
||||
assert mock_mqtt_transport.publish.call_args[1]["topic"] == fake_publish_topic
|
||||
assert mock_mqtt_transport.publish.call_args[1]["payload"] == fake_mqtt_payload
|
||||
|
@ -299,17 +316,21 @@ class TestSendRegister(object):
|
|||
)
|
||||
|
||||
# verify that we're not connected yet and verify that we havent't published yet
|
||||
mock_provisioning_pipeline.wait_for_on_connected_to_not_be_called()
|
||||
mock_provisioning_pipeline.on_connected.assert_not_called()
|
||||
mock_mqtt_transport.wait_for_publish_to_not_be_called()
|
||||
mock_mqtt_transport.publish.assert_not_called()
|
||||
|
||||
# finish the connection
|
||||
mock_mqtt_transport.on_mqtt_connected()
|
||||
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
|
||||
|
||||
# verify that our connected callback was called and verify that we published the event
|
||||
mock_provisioning_pipeline.on_connected.assert_called_once_with("connected")
|
||||
fake_publish_topic = "$dps/registrations/PUT/iotdps-register/?$rid={}".format(
|
||||
fake_request_id
|
||||
)
|
||||
mock_mqtt_transport.wait_for_publish_to_be_called()
|
||||
assert mock_mqtt_transport.publish.call_count == 1
|
||||
assert mock_mqtt_transport.publish.call_args[1]["topic"] == fake_publish_topic
|
||||
assert mock_mqtt_transport.publish.call_args[1]["payload"] == fake_mqtt_payload
|
||||
|
@ -326,6 +347,7 @@ class TestSendRegister(object):
|
|||
# connect
|
||||
mock_provisioning_pipeline.connect()
|
||||
mock_mqtt_transport.on_mqtt_connected()
|
||||
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
|
||||
|
||||
# send an event
|
||||
callback_1 = MagicMock()
|
||||
|
@ -336,6 +358,7 @@ class TestSendRegister(object):
|
|||
fake_publish_topic = "$dps/registrations/PUT/iotdps-register/?$rid={}".format(
|
||||
fake_request_id_1
|
||||
)
|
||||
mock_mqtt_transport.wait_for_publish_to_be_called()
|
||||
assert mock_mqtt_transport.publish.call_count == 1
|
||||
assert mock_mqtt_transport.publish.call_args[1]["topic"] == fake_publish_topic
|
||||
assert mock_mqtt_transport.publish.call_args[1]["payload"] == fake_msg_1
|
||||
|
@ -349,6 +372,7 @@ class TestSendRegister(object):
|
|||
|
||||
# verify that we've called publish twice and verify that neither send_d2c_message
|
||||
# has completed (because we didn't do anything here to complete it).
|
||||
mock_mqtt_transport.wait_for_publish_to_be_called()
|
||||
assert mock_mqtt_transport.publish.call_count == 2
|
||||
callback_1.assert_not_called()
|
||||
callback_2.assert_not_called()
|
||||
|
@ -360,6 +384,7 @@ class TestSendRegister(object):
|
|||
# connect
|
||||
mock_provisioning_pipeline.connect()
|
||||
mock_mqtt_transport.on_mqtt_connected()
|
||||
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
|
||||
|
||||
# send an event
|
||||
mock_provisioning_pipeline.send_request(
|
||||
|
@ -383,6 +408,7 @@ class TestSendQuery(object):
|
|||
|
||||
mock_provisioning_pipeline.connect()
|
||||
mock_mqtt_transport.on_mqtt_connected()
|
||||
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
|
||||
mock_provisioning_pipeline.send_request(
|
||||
request_id=fake_request_id,
|
||||
request_payload=fake_mqtt_payload,
|
||||
|
@ -402,6 +428,7 @@ class TestSendQuery(object):
|
|||
fake_request_id, fake_operation_id
|
||||
)
|
||||
|
||||
mock_mqtt_transport.wait_for_publish_to_be_called()
|
||||
assert mock_mqtt_transport.publish.call_count == 1
|
||||
assert mock_mqtt_transport.publish.call_args[1]["topic"] == fake_publish_topic
|
||||
assert mock_mqtt_transport.publish.call_args[1]["payload"] == fake_mqtt_payload
|
||||
|
@ -416,6 +443,7 @@ class TestDisconnect(object):
|
|||
|
||||
mock_provisioning_pipeline.connect()
|
||||
mock_mqtt_transport.on_mqtt_connected()
|
||||
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
|
||||
mock_provisioning_pipeline.disconnect()
|
||||
|
||||
mock_mqtt_transport.disconnect.assert_called_once_with()
|
||||
|
@ -434,10 +462,12 @@ class TestDisconnect(object):
|
|||
|
||||
mock_provisioning_pipeline.connect()
|
||||
mock_mqtt_transport.on_mqtt_connected()
|
||||
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
|
||||
|
||||
mock_provisioning_pipeline.disconnect()
|
||||
mock_mqtt_transport.on_mqtt_disconnected()
|
||||
|
||||
mock_provisioning_pipeline.wait_for_on_disconnected_to_be_called()
|
||||
mock_provisioning_pipeline.on_disconnected.assert_called_once_with("disconnected")
|
||||
|
||||
|
||||
|
@ -450,6 +480,7 @@ class TestEnable(object):
|
|||
|
||||
mock_provisioning_pipeline.connect()
|
||||
mock_mqtt_transport.on_mqtt_connected()
|
||||
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
|
||||
mock_provisioning_pipeline.enable_responses()
|
||||
|
||||
assert mock_mqtt_transport.subscribe.call_count == 1
|
||||
|
@ -465,6 +496,7 @@ class TestDisable(object):
|
|||
|
||||
mock_provisioning_pipeline.connect()
|
||||
mock_mqtt_transport.on_mqtt_connected()
|
||||
mock_provisioning_pipeline.wait_for_on_connected_to_be_called()
|
||||
mock_provisioning_pipeline.disable_responses(None)
|
||||
|
||||
assert mock_mqtt_transport.unsubscribe.call_count == 1
|
||||
|
|
|
@ -50,6 +50,7 @@ class TestClientCreate(object):
|
|||
patch_set_sym_client = mocker.patch.object(
|
||||
pipeline_ops_provisioning, "SetSymmetricKeySecurityClientOperation"
|
||||
)
|
||||
patch_set_sym_client.callback = mocker.MagicMock()
|
||||
client = ProvisioningDeviceClient.create_from_symmetric_key(
|
||||
fake_provisioning_host, fake_symmetric_key, fake_registration_id, fake_id_scope
|
||||
)
|
||||
|
|
|
@ -59,7 +59,7 @@ jobs:
|
|||
inputs:
|
||||
testResultsFiles: '**/*-test-results.xml'
|
||||
testRunTitle: 'Python $(python.version)'
|
||||
condition: always()
|
||||
condition: always()
|
||||
|
||||
- task: PublishCodeCoverageResults@1
|
||||
inputs:
|
||||
|
|
Загрузка…
Ссылка в новой задаче