From f4c9c2d37a4859b7544d6481547a40f6e9b77621 Mon Sep 17 00:00:00 2001 From: hallvictoria <59299039+hallvictoria@users.noreply.github.com> Date: Wed, 14 Aug 2024 14:08:07 -0500 Subject: [PATCH] fix: added optional context param for tasks (#1523) * added optional context param for tasks * checks for 3.11 or lower * test fixes * lint & skipping tests * only one check needed * lint + comments * better tests * removed comment --------- Co-authored-by: Victoria Hall --- azure_functions_worker/dispatcher.py | 14 ++++-- .../create_task_with_context/function.json | 15 ++++++ .../create_task_with_context/main.py | 35 +++++++++++++ .../create_task_without_context/function.json | 15 ++++++ .../create_task_without_context/main.py | 20 ++++++++ .../http_functions_stein/function_app.py | 49 +++++++++++++++++++ tests/unittests/test_dispatcher.py | 39 ++++++++++++++- tests/unittests/test_http_functions.py | 15 ++++++ 8 files changed, 198 insertions(+), 4 deletions(-) create mode 100644 tests/unittests/http_functions/create_task_with_context/function.json create mode 100644 tests/unittests/http_functions/create_task_with_context/main.py create mode 100644 tests/unittests/http_functions/create_task_without_context/function.json create mode 100644 tests/unittests/http_functions/create_task_without_context/main.py diff --git a/azure_functions_worker/dispatcher.py b/azure_functions_worker/dispatcher.py index cd081583..820c328f 100644 --- a/azure_functions_worker/dispatcher.py +++ b/azure_functions_worker/dispatcher.py @@ -171,8 +171,11 @@ class Dispatcher(metaclass=DispatcherMeta): start_stream=protos.StartStream( worker_id=self.worker_id))) + # In Python 3.11+, constructing a task has an optional context + # parameter. Allow for this param to be passed to ContextEnabledTask self._loop.set_task_factory( - lambda loop, coro: ContextEnabledTask(coro, loop=loop)) + lambda loop, coro, context=None: ContextEnabledTask( + coro, loop=loop, context=context)) # Detach console logging before enabling GRPC channel logging logger.info('Detaching console logging.') @@ -1068,8 +1071,13 @@ class AsyncLoggingHandler(logging.Handler): class ContextEnabledTask(asyncio.Task): AZURE_INVOCATION_ID = '__azure_function_invocation_id__' - def __init__(self, coro, loop): - super().__init__(coro, loop=loop) + def __init__(self, coro, loop, context=None): + # The context param is only available for 3.11+. If + # not, it can't be sent in the init() call. + if sys.version_info.minor >= 11: + super().__init__(coro, loop=loop, context=context) + else: + super().__init__(coro, loop=loop) current_task = asyncio.current_task(loop) if current_task is not None: diff --git a/tests/unittests/http_functions/create_task_with_context/function.json b/tests/unittests/http_functions/create_task_with_context/function.json new file mode 100644 index 00000000..5d4d8285 --- /dev/null +++ b/tests/unittests/http_functions/create_task_with_context/function.json @@ -0,0 +1,15 @@ +{ + "scriptFile": "main.py", + "bindings": [ + { + "type": "httpTrigger", + "direction": "in", + "name": "req" + }, + { + "type": "http", + "direction": "out", + "name": "$return" + } + ] +} diff --git a/tests/unittests/http_functions/create_task_with_context/main.py b/tests/unittests/http_functions/create_task_with_context/main.py new file mode 100644 index 00000000..f603acd1 --- /dev/null +++ b/tests/unittests/http_functions/create_task_with_context/main.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +import asyncio +import contextvars + +import azure.functions + +num = contextvars.ContextVar('num') + + +async def count(name: str): + # The number of times the loop is executed + # depends on the val set in context + val = num.get() + for i in range(val): + await asyncio.sleep(0.5) + return f"Finished {name} in {val}" + + +async def main(req: azure.functions.HttpRequest): + # Create first task with context num = 5 + num.set(5) + first_ctx = contextvars.copy_context() + first_count_task = asyncio.create_task(count("Hello World"), context=first_ctx) + + # Create second task with context num = 10 + num.set(10) + second_ctx = contextvars.copy_context() + second_count_task = asyncio.create_task(count("Hello World"), context=second_ctx) + + # Execute tasks + first_count_val = await first_count_task + second_count_val = await second_count_task + + return f'{first_count_val + " | " + second_count_val}' diff --git a/tests/unittests/http_functions/create_task_without_context/function.json b/tests/unittests/http_functions/create_task_without_context/function.json new file mode 100644 index 00000000..5d4d8285 --- /dev/null +++ b/tests/unittests/http_functions/create_task_without_context/function.json @@ -0,0 +1,15 @@ +{ + "scriptFile": "main.py", + "bindings": [ + { + "type": "httpTrigger", + "direction": "in", + "name": "req" + }, + { + "type": "http", + "direction": "out", + "name": "$return" + } + ] +} diff --git a/tests/unittests/http_functions/create_task_without_context/main.py b/tests/unittests/http_functions/create_task_without_context/main.py new file mode 100644 index 00000000..c7ee21f7 --- /dev/null +++ b/tests/unittests/http_functions/create_task_without_context/main.py @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +import asyncio + +import azure.functions + + +async def count(name: str, num: int): + # The number of times the loop executes is decided by a + # user-defined param + for i in range(num): + await asyncio.sleep(0.5) + return f"Finished {name} in {num}" + + +async def main(req: azure.functions.HttpRequest): + # No context is being sent into asyncio.create_task + count_task = asyncio.create_task(count("Hello World", 5)) + count_val = await count_task + return f'{count_val}' diff --git a/tests/unittests/http_functions/http_functions_stein/function_app.py b/tests/unittests/http_functions/http_functions_stein/function_app.py index 4dd70303..112813de 100644 --- a/tests/unittests/http_functions/http_functions_stein/function_app.py +++ b/tests/unittests/http_functions/http_functions_stein/function_app.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. import asyncio +import contextvars import hashlib import json import logging @@ -14,6 +15,25 @@ app = func.FunctionApp() logger = logging.getLogger("my-function") +num = contextvars.ContextVar('num') + + +async def count_with_context(name: str): + # The number of times the loop is executed + # depends on the val set in context + val = num.get() + for i in range(val): + await asyncio.sleep(0.5) + return f"Finished {name} in {val}" + + +async def count_without_context(name: str, number: int): + # The number of times the loop executes is decided by a + # user-defined param + for i in range(number): + await asyncio.sleep(0.5) + return f"Finished {name} in {number}" + @app.route(route="return_str") def return_str(req: func.HttpRequest) -> str: @@ -404,3 +424,32 @@ def set_cookie_resp_header_empty( resp.headers.add("Set-Cookie", '') return resp + + +@app.route('create_task_with_context') +async def create_task_with_context(req: func.HttpRequest): + # Create first task with context num = 5 + num.set(5) + first_ctx = contextvars.copy_context() + first_count_task = asyncio.create_task( + count_with_context("Hello World"), context=first_ctx) + + # Create second task with context num = 10 + num.set(10) + second_ctx = contextvars.copy_context() + second_count_task = asyncio.create_task( + count_with_context("Hello World"), context=second_ctx) + + # Execute tasks + first_count_val = await first_count_task + second_count_val = await second_count_task + + return f'{first_count_val + " | " + second_count_val}' + + +@app.route('create_task_without_context') +async def create_task_without_context(req: func.HttpRequest): + # No context is being sent into asyncio.create_task + count_task = asyncio.create_task(count_without_context("Hello World", 5)) + count_val = await count_task + return f'{count_val}' diff --git a/tests/unittests/test_dispatcher.py b/tests/unittests/test_dispatcher.py index a924a734..32eca34b 100644 --- a/tests/unittests/test_dispatcher.py +++ b/tests/unittests/test_dispatcher.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import asyncio import collections as col +import contextvars import os import sys import unittest @@ -21,7 +22,7 @@ from azure_functions_worker.constants import ( PYTHON_THREADPOOL_THREAD_COUNT_MAX_37, PYTHON_THREADPOOL_THREAD_COUNT_MIN, ) -from azure_functions_worker.dispatcher import Dispatcher +from azure_functions_worker.dispatcher import Dispatcher, ContextEnabledTask from azure_functions_worker.version import VERSION SysVersionInfo = col.namedtuple("VersionInfo", ["major", "minor", "micro", @@ -989,3 +990,39 @@ class TestDispatcherIndexingInInit(unittest.TestCase): self.assertEqual( response.function_load_response.result.exception.message, "Exception: Mocked Exception") + + +class TestContextEnabledTask(unittest.TestCase): + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + def test_init_with_context(self): + # Since ContextEnabledTask accepts the context param, + # no errors will be thrown here + num = contextvars.ContextVar('num') + num.set(5) + ctx = contextvars.copy_context() + exception_raised = False + try: + self.loop.set_task_factory( + lambda loop, coro, context=None: ContextEnabledTask( + coro, loop=loop, context=ctx)) + except TypeError: + exception_raised = True + self.assertFalse(exception_raised) + + async def test_init_without_context(self): + # If the context param is not defined, + # no errors will be thrown for backwards compatibility + exception_raised = False + try: + self.loop.set_task_factory( + lambda loop, coro: ContextEnabledTask( + coro, loop=loop)) + except TypeError: + exception_raised = True + self.assertFalse(exception_raised) diff --git a/tests/unittests/test_http_functions.py b/tests/unittests/test_http_functions.py index 61b6c433..03e0b580 100644 --- a/tests/unittests/test_http_functions.py +++ b/tests/unittests/test_http_functions.py @@ -446,6 +446,21 @@ class TestHttpFunctions(testutils.WebHostTestCase): # System logs should not exist in host_out self.assertNotIn('parallelly_log_system at disguised_logger', host_out) + @skipIf(sys.version_info.minor < 11, + "The context param is only available for 3.11+") + def test_create_task_with_context(self): + r = self.webhost.request('GET', 'create_task_with_context') + + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, 'Finished Hello World in 5' + ' | Finished Hello World in 10') + + def test_create_task_without_context(self): + r = self.webhost.request('GET', 'create_task_without_context') + + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, 'Finished Hello World in 5') + class TestHttpFunctionsStein(TestHttpFunctions):