fix: added optional context param for tasks (#1523)

* added optional context param for tasks

* checks for 3.11 or lower

* test fixes

* lint & skipping tests

* only one check needed

* lint + comments

* better tests

* removed comment

---------

Co-authored-by: Victoria Hall <victoria.hall@microsoft.com>
This commit is contained in:
hallvictoria 2024-08-14 14:08:07 -05:00 коммит произвёл GitHub
Родитель bbc683e4fb
Коммит f4c9c2d37a
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
8 изменённых файлов: 198 добавлений и 4 удалений

Просмотреть файл

@ -171,8 +171,11 @@ class Dispatcher(metaclass=DispatcherMeta):
start_stream=protos.StartStream(
worker_id=self.worker_id)))
# In Python 3.11+, constructing a task has an optional context
# parameter. Allow for this param to be passed to ContextEnabledTask
self._loop.set_task_factory(
lambda loop, coro: ContextEnabledTask(coro, loop=loop))
lambda loop, coro, context=None: ContextEnabledTask(
coro, loop=loop, context=context))
# Detach console logging before enabling GRPC channel logging
logger.info('Detaching console logging.')
@ -1068,8 +1071,13 @@ class AsyncLoggingHandler(logging.Handler):
class ContextEnabledTask(asyncio.Task):
AZURE_INVOCATION_ID = '__azure_function_invocation_id__'
def __init__(self, coro, loop):
super().__init__(coro, loop=loop)
def __init__(self, coro, loop, context=None):
# The context param is only available for 3.11+. If
# not, it can't be sent in the init() call.
if sys.version_info.minor >= 11:
super().__init__(coro, loop=loop, context=context)
else:
super().__init__(coro, loop=loop)
current_task = asyncio.current_task(loop)
if current_task is not None:

Просмотреть файл

@ -0,0 +1,15 @@
{
"scriptFile": "main.py",
"bindings": [
{
"type": "httpTrigger",
"direction": "in",
"name": "req"
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}

Просмотреть файл

@ -0,0 +1,35 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import asyncio
import contextvars
import azure.functions
num = contextvars.ContextVar('num')
async def count(name: str):
# The number of times the loop is executed
# depends on the val set in context
val = num.get()
for i in range(val):
await asyncio.sleep(0.5)
return f"Finished {name} in {val}"
async def main(req: azure.functions.HttpRequest):
# Create first task with context num = 5
num.set(5)
first_ctx = contextvars.copy_context()
first_count_task = asyncio.create_task(count("Hello World"), context=first_ctx)
# Create second task with context num = 10
num.set(10)
second_ctx = contextvars.copy_context()
second_count_task = asyncio.create_task(count("Hello World"), context=second_ctx)
# Execute tasks
first_count_val = await first_count_task
second_count_val = await second_count_task
return f'{first_count_val + " | " + second_count_val}'

Просмотреть файл

@ -0,0 +1,15 @@
{
"scriptFile": "main.py",
"bindings": [
{
"type": "httpTrigger",
"direction": "in",
"name": "req"
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}

Просмотреть файл

@ -0,0 +1,20 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import asyncio
import azure.functions
async def count(name: str, num: int):
# The number of times the loop executes is decided by a
# user-defined param
for i in range(num):
await asyncio.sleep(0.5)
return f"Finished {name} in {num}"
async def main(req: azure.functions.HttpRequest):
# No context is being sent into asyncio.create_task
count_task = asyncio.create_task(count("Hello World", 5))
count_val = await count_task
return f'{count_val}'

Просмотреть файл

@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import asyncio
import contextvars
import hashlib
import json
import logging
@ -14,6 +15,25 @@ app = func.FunctionApp()
logger = logging.getLogger("my-function")
num = contextvars.ContextVar('num')
async def count_with_context(name: str):
# The number of times the loop is executed
# depends on the val set in context
val = num.get()
for i in range(val):
await asyncio.sleep(0.5)
return f"Finished {name} in {val}"
async def count_without_context(name: str, number: int):
# The number of times the loop executes is decided by a
# user-defined param
for i in range(number):
await asyncio.sleep(0.5)
return f"Finished {name} in {number}"
@app.route(route="return_str")
def return_str(req: func.HttpRequest) -> str:
@ -404,3 +424,32 @@ def set_cookie_resp_header_empty(
resp.headers.add("Set-Cookie", '')
return resp
@app.route('create_task_with_context')
async def create_task_with_context(req: func.HttpRequest):
# Create first task with context num = 5
num.set(5)
first_ctx = contextvars.copy_context()
first_count_task = asyncio.create_task(
count_with_context("Hello World"), context=first_ctx)
# Create second task with context num = 10
num.set(10)
second_ctx = contextvars.copy_context()
second_count_task = asyncio.create_task(
count_with_context("Hello World"), context=second_ctx)
# Execute tasks
first_count_val = await first_count_task
second_count_val = await second_count_task
return f'{first_count_val + " | " + second_count_val}'
@app.route('create_task_without_context')
async def create_task_without_context(req: func.HttpRequest):
# No context is being sent into asyncio.create_task
count_task = asyncio.create_task(count_without_context("Hello World", 5))
count_val = await count_task
return f'{count_val}'

Просмотреть файл

@ -2,6 +2,7 @@
# Licensed under the MIT License.
import asyncio
import collections as col
import contextvars
import os
import sys
import unittest
@ -21,7 +22,7 @@ from azure_functions_worker.constants import (
PYTHON_THREADPOOL_THREAD_COUNT_MAX_37,
PYTHON_THREADPOOL_THREAD_COUNT_MIN,
)
from azure_functions_worker.dispatcher import Dispatcher
from azure_functions_worker.dispatcher import Dispatcher, ContextEnabledTask
from azure_functions_worker.version import VERSION
SysVersionInfo = col.namedtuple("VersionInfo", ["major", "minor", "micro",
@ -989,3 +990,39 @@ class TestDispatcherIndexingInInit(unittest.TestCase):
self.assertEqual(
response.function_load_response.result.exception.message,
"Exception: Mocked Exception")
class TestContextEnabledTask(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
def tearDown(self):
self.loop.close()
def test_init_with_context(self):
# Since ContextEnabledTask accepts the context param,
# no errors will be thrown here
num = contextvars.ContextVar('num')
num.set(5)
ctx = contextvars.copy_context()
exception_raised = False
try:
self.loop.set_task_factory(
lambda loop, coro, context=None: ContextEnabledTask(
coro, loop=loop, context=ctx))
except TypeError:
exception_raised = True
self.assertFalse(exception_raised)
async def test_init_without_context(self):
# If the context param is not defined,
# no errors will be thrown for backwards compatibility
exception_raised = False
try:
self.loop.set_task_factory(
lambda loop, coro: ContextEnabledTask(
coro, loop=loop))
except TypeError:
exception_raised = True
self.assertFalse(exception_raised)

Просмотреть файл

@ -446,6 +446,21 @@ class TestHttpFunctions(testutils.WebHostTestCase):
# System logs should not exist in host_out
self.assertNotIn('parallelly_log_system at disguised_logger', host_out)
@skipIf(sys.version_info.minor < 11,
"The context param is only available for 3.11+")
def test_create_task_with_context(self):
r = self.webhost.request('GET', 'create_task_with_context')
self.assertEqual(r.status_code, 200)
self.assertEqual(r.text, 'Finished Hello World in 5'
' | Finished Hello World in 10')
def test_create_task_without_context(self):
r = self.webhost.request('GET', 'create_task_without_context')
self.assertEqual(r.status_code, 200)
self.assertEqual(r.text, 'Finished Hello World in 5')
class TestHttpFunctionsStein(TestHttpFunctions):