fix: added optional context param for tasks (#1523)
* added optional context param for tasks * checks for 3.11 or lower * test fixes * lint & skipping tests * only one check needed * lint + comments * better tests * removed comment --------- Co-authored-by: Victoria Hall <victoria.hall@microsoft.com>
This commit is contained in:
Родитель
bbc683e4fb
Коммит
f4c9c2d37a
|
@ -171,8 +171,11 @@ class Dispatcher(metaclass=DispatcherMeta):
|
|||
start_stream=protos.StartStream(
|
||||
worker_id=self.worker_id)))
|
||||
|
||||
# In Python 3.11+, constructing a task has an optional context
|
||||
# parameter. Allow for this param to be passed to ContextEnabledTask
|
||||
self._loop.set_task_factory(
|
||||
lambda loop, coro: ContextEnabledTask(coro, loop=loop))
|
||||
lambda loop, coro, context=None: ContextEnabledTask(
|
||||
coro, loop=loop, context=context))
|
||||
|
||||
# Detach console logging before enabling GRPC channel logging
|
||||
logger.info('Detaching console logging.')
|
||||
|
@ -1068,8 +1071,13 @@ class AsyncLoggingHandler(logging.Handler):
|
|||
class ContextEnabledTask(asyncio.Task):
|
||||
AZURE_INVOCATION_ID = '__azure_function_invocation_id__'
|
||||
|
||||
def __init__(self, coro, loop):
|
||||
super().__init__(coro, loop=loop)
|
||||
def __init__(self, coro, loop, context=None):
|
||||
# The context param is only available for 3.11+. If
|
||||
# not, it can't be sent in the init() call.
|
||||
if sys.version_info.minor >= 11:
|
||||
super().__init__(coro, loop=loop, context=context)
|
||||
else:
|
||||
super().__init__(coro, loop=loop)
|
||||
|
||||
current_task = asyncio.current_task(loop)
|
||||
if current_task is not None:
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
{
|
||||
"scriptFile": "main.py",
|
||||
"bindings": [
|
||||
{
|
||||
"type": "httpTrigger",
|
||||
"direction": "in",
|
||||
"name": "req"
|
||||
},
|
||||
{
|
||||
"type": "http",
|
||||
"direction": "out",
|
||||
"name": "$return"
|
||||
}
|
||||
]
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
import asyncio
|
||||
import contextvars
|
||||
|
||||
import azure.functions
|
||||
|
||||
num = contextvars.ContextVar('num')
|
||||
|
||||
|
||||
async def count(name: str):
|
||||
# The number of times the loop is executed
|
||||
# depends on the val set in context
|
||||
val = num.get()
|
||||
for i in range(val):
|
||||
await asyncio.sleep(0.5)
|
||||
return f"Finished {name} in {val}"
|
||||
|
||||
|
||||
async def main(req: azure.functions.HttpRequest):
|
||||
# Create first task with context num = 5
|
||||
num.set(5)
|
||||
first_ctx = contextvars.copy_context()
|
||||
first_count_task = asyncio.create_task(count("Hello World"), context=first_ctx)
|
||||
|
||||
# Create second task with context num = 10
|
||||
num.set(10)
|
||||
second_ctx = contextvars.copy_context()
|
||||
second_count_task = asyncio.create_task(count("Hello World"), context=second_ctx)
|
||||
|
||||
# Execute tasks
|
||||
first_count_val = await first_count_task
|
||||
second_count_val = await second_count_task
|
||||
|
||||
return f'{first_count_val + " | " + second_count_val}'
|
|
@ -0,0 +1,15 @@
|
|||
{
|
||||
"scriptFile": "main.py",
|
||||
"bindings": [
|
||||
{
|
||||
"type": "httpTrigger",
|
||||
"direction": "in",
|
||||
"name": "req"
|
||||
},
|
||||
{
|
||||
"type": "http",
|
||||
"direction": "out",
|
||||
"name": "$return"
|
||||
}
|
||||
]
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
import asyncio
|
||||
|
||||
import azure.functions
|
||||
|
||||
|
||||
async def count(name: str, num: int):
|
||||
# The number of times the loop executes is decided by a
|
||||
# user-defined param
|
||||
for i in range(num):
|
||||
await asyncio.sleep(0.5)
|
||||
return f"Finished {name} in {num}"
|
||||
|
||||
|
||||
async def main(req: azure.functions.HttpRequest):
|
||||
# No context is being sent into asyncio.create_task
|
||||
count_task = asyncio.create_task(count("Hello World", 5))
|
||||
count_val = await count_task
|
||||
return f'{count_val}'
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
import asyncio
|
||||
import contextvars
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
|
@ -14,6 +15,25 @@ app = func.FunctionApp()
|
|||
|
||||
logger = logging.getLogger("my-function")
|
||||
|
||||
num = contextvars.ContextVar('num')
|
||||
|
||||
|
||||
async def count_with_context(name: str):
|
||||
# The number of times the loop is executed
|
||||
# depends on the val set in context
|
||||
val = num.get()
|
||||
for i in range(val):
|
||||
await asyncio.sleep(0.5)
|
||||
return f"Finished {name} in {val}"
|
||||
|
||||
|
||||
async def count_without_context(name: str, number: int):
|
||||
# The number of times the loop executes is decided by a
|
||||
# user-defined param
|
||||
for i in range(number):
|
||||
await asyncio.sleep(0.5)
|
||||
return f"Finished {name} in {number}"
|
||||
|
||||
|
||||
@app.route(route="return_str")
|
||||
def return_str(req: func.HttpRequest) -> str:
|
||||
|
@ -404,3 +424,32 @@ def set_cookie_resp_header_empty(
|
|||
resp.headers.add("Set-Cookie", '')
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
@app.route('create_task_with_context')
|
||||
async def create_task_with_context(req: func.HttpRequest):
|
||||
# Create first task with context num = 5
|
||||
num.set(5)
|
||||
first_ctx = contextvars.copy_context()
|
||||
first_count_task = asyncio.create_task(
|
||||
count_with_context("Hello World"), context=first_ctx)
|
||||
|
||||
# Create second task with context num = 10
|
||||
num.set(10)
|
||||
second_ctx = contextvars.copy_context()
|
||||
second_count_task = asyncio.create_task(
|
||||
count_with_context("Hello World"), context=second_ctx)
|
||||
|
||||
# Execute tasks
|
||||
first_count_val = await first_count_task
|
||||
second_count_val = await second_count_task
|
||||
|
||||
return f'{first_count_val + " | " + second_count_val}'
|
||||
|
||||
|
||||
@app.route('create_task_without_context')
|
||||
async def create_task_without_context(req: func.HttpRequest):
|
||||
# No context is being sent into asyncio.create_task
|
||||
count_task = asyncio.create_task(count_without_context("Hello World", 5))
|
||||
count_val = await count_task
|
||||
return f'{count_val}'
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
# Licensed under the MIT License.
|
||||
import asyncio
|
||||
import collections as col
|
||||
import contextvars
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
@ -21,7 +22,7 @@ from azure_functions_worker.constants import (
|
|||
PYTHON_THREADPOOL_THREAD_COUNT_MAX_37,
|
||||
PYTHON_THREADPOOL_THREAD_COUNT_MIN,
|
||||
)
|
||||
from azure_functions_worker.dispatcher import Dispatcher
|
||||
from azure_functions_worker.dispatcher import Dispatcher, ContextEnabledTask
|
||||
from azure_functions_worker.version import VERSION
|
||||
|
||||
SysVersionInfo = col.namedtuple("VersionInfo", ["major", "minor", "micro",
|
||||
|
@ -989,3 +990,39 @@ class TestDispatcherIndexingInInit(unittest.TestCase):
|
|||
self.assertEqual(
|
||||
response.function_load_response.result.exception.message,
|
||||
"Exception: Mocked Exception")
|
||||
|
||||
|
||||
class TestContextEnabledTask(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
def tearDown(self):
|
||||
self.loop.close()
|
||||
|
||||
def test_init_with_context(self):
|
||||
# Since ContextEnabledTask accepts the context param,
|
||||
# no errors will be thrown here
|
||||
num = contextvars.ContextVar('num')
|
||||
num.set(5)
|
||||
ctx = contextvars.copy_context()
|
||||
exception_raised = False
|
||||
try:
|
||||
self.loop.set_task_factory(
|
||||
lambda loop, coro, context=None: ContextEnabledTask(
|
||||
coro, loop=loop, context=ctx))
|
||||
except TypeError:
|
||||
exception_raised = True
|
||||
self.assertFalse(exception_raised)
|
||||
|
||||
async def test_init_without_context(self):
|
||||
# If the context param is not defined,
|
||||
# no errors will be thrown for backwards compatibility
|
||||
exception_raised = False
|
||||
try:
|
||||
self.loop.set_task_factory(
|
||||
lambda loop, coro: ContextEnabledTask(
|
||||
coro, loop=loop))
|
||||
except TypeError:
|
||||
exception_raised = True
|
||||
self.assertFalse(exception_raised)
|
||||
|
|
|
@ -446,6 +446,21 @@ class TestHttpFunctions(testutils.WebHostTestCase):
|
|||
# System logs should not exist in host_out
|
||||
self.assertNotIn('parallelly_log_system at disguised_logger', host_out)
|
||||
|
||||
@skipIf(sys.version_info.minor < 11,
|
||||
"The context param is only available for 3.11+")
|
||||
def test_create_task_with_context(self):
|
||||
r = self.webhost.request('GET', 'create_task_with_context')
|
||||
|
||||
self.assertEqual(r.status_code, 200)
|
||||
self.assertEqual(r.text, 'Finished Hello World in 5'
|
||||
' | Finished Hello World in 10')
|
||||
|
||||
def test_create_task_without_context(self):
|
||||
r = self.webhost.request('GET', 'create_task_without_context')
|
||||
|
||||
self.assertEqual(r.status_code, 200)
|
||||
self.assertEqual(r.text, 'Finished Hello World in 5')
|
||||
|
||||
|
||||
class TestHttpFunctionsStein(TestHttpFunctions):
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче