Replace post-request-task with transaction.on_commit on every delay/apply_async (#22377)

* Replace post-request-task with transaction.on_commit on every delay/apply_async

According to Django docs, transaction.on_commit should:
- only be called once the outer transaction.atomic() has been executed,
  guaranteeing that it works even if there are savepoints.
- be discarded on rollback

So it should be a drop-in replacement with the added benefit that it handles
tasks and commands doing transactions outside of the request-response cycle.
This commit is contained in:
Mathieu Pillard 2024-06-24 12:18:18 +02:00 committed by GitHub
Parent 9d68cc1d24
Commit c232800a2e
No known key found for this signature
GPG key ID: B5690EEEBB952194
14 changed files: 84 additions and 173 deletions

Просмотреть файл

@ -557,9 +557,6 @@ django-extensions==3.2.3 \
django-multidb-router==0.10 \
--hash=sha256:6802bbfce3a0dac343f93a290df12a9979aed99b350bc0148a7430353efae8b6 \
--hash=sha256:a10bf4b664465090a5f2e8543dd1491eaf8dde7c551270e546324454e7280964
django-post-request-task==0.5 \
--hash=sha256:26c03b5d06eb1705b2438bb719575fac4aae7f34c32837480202acad556edb3c \
--hash=sha256:91df3893c9551851cd10568ef3b2cf358bd87e8c65dce728c37196a8de34247c
django-recaptcha==4.0.0 \
--hash=sha256:0d912d5c7c009df4e47accd25029133d47a74342dbd2a8edc2877b6bffa971a3 \
--hash=sha256:5316438f97700c431d65351470d1255047e3f2cd9af0f2f13592b637dad9213e

Просмотреть файл

@ -81,6 +81,7 @@ CELERY_TASK_ALWAYS_EAGER = True
CELERY_IMPORTS += (
'olympia.amo.tests.test_celery',
'olympia.search.tests.test_commands',
'olympia.devhub.tests.test_tasks',
)
CELERY_TASK_ROUTES.update(

Просмотреть файл

@ -8,14 +8,16 @@ is directly being run/imported by Celery)
"""
import datetime
import functools
from django.core.cache import cache
from django.db import transaction
from celery import Celery, group
from celery.app.task import Task
from celery.signals import task_failure, task_postrun, task_prerun
from django_statsd.clients import statsd
from kombu import serialization
from post_request_task.task import PostRequestTask
import olympia.core.logger
@ -23,10 +25,11 @@ import olympia.core.logger
log = olympia.core.logger.getLogger('z.task')
class AMOTask(PostRequestTask):
"""A custom celery Task base class that inherits from `PostRequestTask`
to delay tasks and adds a special hack to still perform a serialization
roundtrip in eager mode, to mimic what happens in production in tests.
class AMOTask(Task):
"""A custom celery Task base class to always trigger tasks after the
current transaction has been committed, and also adds a special hack to
still perform a serialization roundtrip in eager mode, to mimic what
happens in production in tests.
The serialization is applied both to apply_async() and apply() to work
around the fact that celery groups have their own apply_async() method that
@ -50,13 +53,32 @@ class AMOTask(PostRequestTask):
args, kwargs = serialization.loads(data, content_type, content_encoding)
return args, kwargs
def original_apply_async(self, *args, **kwargs):
    """Alias for celery's original apply_async() method, allowing us to
    trigger a task immediately, without waiting for the current
    transaction to be committed. Use with caution."""
    # Bypasses our transaction.on_commit() wrapper below and goes
    # straight to celery's own Task.apply_async().
    return super().apply_async(*args, **kwargs)
def apply_async(self, args=None, kwargs=None, **options):
    """Send the task, deferring it until the current transaction commits.

    In eager mode (used in tests) the task is triggered immediately
    instead, after a serialization roundtrip mimicking production,
    because no transaction is ever committed there and
    transaction.on_commit() would never fire.

    Always returns None (the on_commit path has no AsyncResult to
    return, so None is returned in all cases for consistency).
    """
    # NOTE: removed a stale leftover `return super().apply_async(...)`
    # line from the pre-change implementation that short-circuited the
    # eager branch and made the code below it unreachable.
    if app.conf.task_always_eager:
        args, kwargs = self._serialize_args_and_kwargs_for_eager_mode(
            args=args, kwargs=kwargs, **options
        )
        # In eager mode, immediately call original apply_async as we are
        # using eager mode for tests, where no transaction is ever
        # actually committed so transaction.on_commit() is never called.
        self.original_apply_async(args=args, kwargs=kwargs, **options)
    else:
        # In normal mode, wait until the current transaction is committed
        # to actually send the task.
        transaction.on_commit(
            functools.partial(
                self.original_apply_async, args=args, kwargs=kwargs, **options
            )
        )
    # We can't return anything meaningful if we're going through the
    # on_commit path, so for consistency return None in all cases.
    return None
def apply(self, args=None, kwargs=None, **options):
if app.conf.task_always_eager:

Просмотреть файл

@ -217,7 +217,9 @@ def remotesettings():
# a worker, and since workers have different network
# configuration than the Web head, we use a task to check
# the connectivity to the Remote Settings server.
# Since we want the result immediately, bypass django-post-request-task.
# Since we want the result immediately, use original_apply_async() to avoid
# waiting for the transaction django creates for the request like we
# usually do.
result = monitor_remote_settings.original_apply_async()
try:
status = result.get(timeout=settings.REMOTE_SETTINGS_CHECK_TIMEOUT_SECONDS)

Просмотреть файл

@ -5,11 +5,8 @@ from datetime import timedelta
from unittest import mock
from django.conf import settings
from django.core.signals import request_finished, request_started
from django.test.testcases import TransactionTestCase
from celery import group
from post_request_task.task import _discard_tasks, _stop_queuing_tasks
from olympia.amo.celery import app, create_chunked_tasks_signatures, task
from olympia.amo.tests import TestCase
@ -93,11 +90,17 @@ def sleeping_task(time_to_sleep):
class TestCeleryWorker(TestCase):
def trigger_fake_task(self, task_func):
    # Our delay()/apply_async() overrides only fire once the transaction
    # is committed and return None rather than an AsyncResult, so call
    # original_apply_async() directly to get a result we can inspect.
    async_result = task_func.original_apply_async()
    async_result.get()
    return async_result
@mock.patch('olympia.amo.celery.cache')
def test_start_task_timer(self, celery_cache):
result = fake_task_with_result.delay()
result.get()
result = self.trigger_fake_task(fake_task_with_result)
assert celery_cache.set.called
assert celery_cache.set.call_args[0][0] == f'task_start_time.{result.id}'
@ -108,8 +111,7 @@ class TestCeleryWorker(TestCase):
task_start = utc_millesecs_from_epoch(minute_ago)
celery_cache.get.return_value = task_start
result = fake_task_with_result.delay()
result.get()
result = self.trigger_fake_task(fake_task_with_result)
approx_run_time = utc_millesecs_from_epoch() - task_start
assert (
@ -130,47 +132,5 @@ class TestCeleryWorker(TestCase):
@mock.patch('olympia.amo.celery.statsd')
def test_handle_cache_miss_for_stats(self, celery_cache, celery_statsd):
celery_cache.get.return_value = None # cache miss
fake_task.delay()
self.trigger_fake_task(fake_task)
assert not celery_statsd.timing.called
class TestTaskQueued(TransactionTestCase):
    """Test that tasks are queued and only triggered when a request finishes.

    Tests our integration with django-post-request-task.
    """

    def setUp(self):
        super().setUp()
        # Reset call tracking and drop any tasks left over from a
        # previous test so each test starts from a clean queue.
        fake_task_func.reset_mock()
        _discard_tasks()

    def tearDown(self):
        super().tearDown()
        fake_task_func.reset_mock()
        _discard_tasks()
        # Make sure queuing is disabled so later tests see fresh state.
        _stop_queuing_tasks()

    def test_not_queued_outside_request_response_cycle(self):
        # No request in progress: the task runs immediately.
        fake_task.delay()
        assert fake_task_func.call_count == 1

    def test_queued_inside_request_response_cycle(self):
        # While a request is in flight, tasks are held back and only
        # sent once the request finishes.
        request_started.send(sender=self)
        fake_task.delay()
        assert fake_task_func.call_count == 0
        request_finished.send_robust(sender=self)
        assert fake_task_func.call_count == 1

    def test_no_dedupe_outside_request_response_cycle(self):
        # Outside a request, identical tasks are NOT de-duplicated.
        fake_task.delay()
        fake_task.delay()
        assert fake_task_func.call_count == 2

    def test_dedupe_inside_request_response_cycle(self):
        # During a request, identical queued tasks are de-duplicated and
        # sent only once when the request finishes.
        request_started.send(sender=self)
        fake_task.delay()
        fake_task.delay()
        assert fake_task_func.call_count == 0
        request_finished.send_robust(sender=self)
        assert fake_task_func.call_count == 1

Просмотреть файл

@ -253,7 +253,8 @@ def send_mail(
kwargs.update(options)
# Email subject *must not* contain newlines
args = (list(recipients), ' '.join(subject.splitlines()), message)
return send_email.delay(*args, **kwargs)
send_email.delay(*args, **kwargs)
return True
if white_list:
if perm_setting:

Просмотреть файл

@ -72,10 +72,3 @@ class CoreConfig(AppConfig):
# Ignore Python warnings unless we're running in debug mode.
if not settings.DEBUG:
warnings.simplefilter('ignore')
self.enable_post_request_task()
def enable_post_request_task(self):
    """Import post_request_task so that it can listen to `request_started`
    signal before the first request is handled."""
    # The import itself registers the signal handlers as a side effect;
    # nothing else needs to be done with the module.
    import post_request_task.task  # noqa

Просмотреть файл

@ -73,9 +73,10 @@ def validate(file_, *, final_task=None, theme_specific=False):
if task_id:
return AsyncResult(task_id)
else:
result = validator.get_task().delay()
cache.set(validator.cache_key, result.task_id, 5 * 60)
return result
task = validator.get_task()
task_id = task.freeze().id
cache.set(validator.cache_key, task_id, 5 * 60)
return task.delay()
def validate_and_submit(*, addon, upload, client_info, theme_specific=False):

Просмотреть файл

@ -103,7 +103,7 @@ class TestAddonsLinterListed(UploadMixin, TestCase):
@mock.patch.object(utils.Validator, 'get_task')
def test_run_once_per_file(self, get_task_mock):
"""Tests that only a single validation task is run for a given file."""
get_task_mock.return_value.delay.return_value = mock.Mock(task_id='42')
get_task_mock.return_value.freeze.return_value = mock.Mock(id='42')
assert isinstance(tasks.validate(self.file), mock.Mock)
assert get_task_mock.return_value.delay.call_count == 1
@ -119,7 +119,7 @@ class TestAddonsLinterListed(UploadMixin, TestCase):
def test_run_once_file_upload(self, get_task_mock):
"""Tests that only a single validation task is run for a given file
upload."""
get_task_mock.return_value.delay.return_value = mock.Mock(task_id='42')
get_task_mock.return_value.freeze.return_value = mock.Mock(id='42')
assert isinstance(tasks.validate(self.file_upload), mock.Mock)
assert get_task_mock.return_value.delay.call_count == 1

Просмотреть файл

@ -6,13 +6,6 @@ from django.conf import settings
from django.db import transaction
from django.utils import translation
from post_request_task.task import (
_discard_tasks,
_send_tasks_and_stop_queuing,
_start_queuing_tasks,
_stop_queuing_tasks,
)
import olympia.core.logger
from olympia import amo
from olympia.activity.models import ActivityLog
@ -186,21 +179,7 @@ def bump_addon_version(old_version):
)
parsed_data['approval_notes'] = old_version.approval_notes
# Discard any existing celery tasks that may have been queued before:
# If there are any left at this point, it means the transaction from
# the previous loop iteration was not committed and we shouldn't
# trigger the corresponding tasks.
_discard_tasks()
# Queue celery tasks for this transaction, avoiding triggering them too
# soon before the transaction has been committed...
# (useful for things like theme preview generation)
_start_queuing_tasks()
with transaction.atomic():
# ...and release the queued tasks to celery once transaction
# is committed.
transaction.on_commit(_send_tasks_and_stop_queuing)
# Create a version object out of the FileUpload + parsed data.
new_version = Version.from_upload(
upload,
@ -228,13 +207,6 @@ def bump_addon_version(old_version):
log.exception(f'Failed re-signing file {old_file_obj.pk}', exc_info=True)
# Next loop iteration will clear the task queue.
return
finally:
# Stop post request task queue before moving on (useful in tests to
# leave a fresh state for the next test. Note that we don't want to
# send or clear queued tasks (they may belong to a transaction that
# has been rolled back, or they may not have been processed by the
# on commit handler yet).
_stop_queuing_tasks()
# Now notify the developers of that add-on. Any exception should have
# caused an early return before reaching this point.

Просмотреть файл

@ -1013,12 +1013,6 @@ LOGGING = {
'propagate': False,
},
'parso': {'handlers': ['null'], 'level': logging.INFO, 'propagate': False},
'post_request_task': {
'handlers': ['mozlog'],
# Ignore INFO or DEBUG from post-request-task, it logs too much.
'level': logging.WARNING,
'propagate': False,
},
'sentry_sdk': {
'handlers': ['mozlog'],
'level': logging.WARNING,

Просмотреть файл

@ -6,12 +6,6 @@ from django.db import transaction
import waffle
from django_statsd.clients import statsd
from post_request_task.task import (
_discard_tasks,
_send_tasks_and_stop_queuing,
_start_queuing_tasks,
_stop_queuing_tasks,
)
import olympia.core.logger
from olympia import amo
@ -100,21 +94,8 @@ class Command(BaseCommand):
# our own.
set_reviewing_cache(version.addon.pk, settings.TASK_USER_ID)
# Discard any existing celery tasks that may have been queued before:
# If there are any left at this point, it means the transaction from
# the previous loop iteration was not committed and we shouldn't
# trigger the corresponding tasks.
_discard_tasks()
# Queue celery tasks for this version, avoiding triggering them too
# soon...
_start_queuing_tasks()
try:
with transaction.atomic():
# ...and release the queued tasks to celery once transaction
# is committed.
transaction.on_commit(_send_tasks_and_stop_queuing)
log.info(
'Processing %s version %s...',
str(version.addon.name),
@ -182,12 +163,6 @@ class Command(BaseCommand):
# Always clear our own lock no matter what happens (but only ours).
if not already_locked:
clear_reviewing_cache(version.addon.pk)
# Stop post request task queue before moving on (useful in tests to
# leave a fresh state for the next test).
_stop_queuing_tasks()
# We also clear any stray queued tasks. We're out of the @atomic block so
# the transaction has either been rolled back or committed.
_discard_tasks()
@statsd.timer('reviewers.auto_approve.approve')
def approve(self, version):

Просмотреть файл

@ -4,13 +4,6 @@ from django.conf import settings
from django.core.management.base import BaseCommand
from django.db import transaction
from post_request_task.task import (
_discard_tasks,
_send_tasks_and_stop_queuing,
_start_queuing_tasks,
_stop_queuing_tasks,
)
import olympia.core.logger
from olympia import amo
from olympia.abuse.models import CinderJob
@ -103,20 +96,8 @@ class Command(BaseCommand):
)
def process_addon(self, *, addon, now):
# Discard any existing celery tasks that may have been queued before:
# If there are any left at this point, it means the transaction from
# the previous loop iteration was not committed and we shouldn't
# trigger the corresponding tasks.
_discard_tasks()
# Queue celery tasks for this version, avoiding triggering them too
# soon...
_start_queuing_tasks()
try:
with transaction.atomic():
# ...and release the queued tasks to celery once transaction
# is committed.
transaction.on_commit(_send_tasks_and_stop_queuing)
latest_version = addon.find_latest_version(channel=amo.CHANNEL_LISTED)
if (
latest_version
@ -158,12 +139,6 @@ class Command(BaseCommand):
finally:
# Always clear our lock no matter what happens.
clear_reviewing_cache(addon.pk)
# Stop post request task queue before moving on (useful in tests to
# leave a fresh state for the next test).
_stop_queuing_tasks()
# We also clear any stray queued tasks. We're out of the @atomic block so
# the transaction has either been rolled back or committed.
_discard_tasks()
@use_primary_db
def handle(self, *args, **kwargs):

Просмотреть файл

@ -1,4 +1,7 @@
import contextlib
from django.core.management.base import BaseCommand
from django.db.transaction import atomic
import olympia.core.logger
from olympia.zadmin.tasks import celery_error
@ -18,6 +21,12 @@ class Command(BaseCommand):
action='store_true',
help='Raise the error in a celery task',
)
parser.add_argument(
'--atomic',
default=False,
action='store_true',
help='Raise the error in an atomic block',
)
parser.add_argument(
'--log',
default=False,
@ -26,18 +35,27 @@ class Command(BaseCommand):
)
def handle(self, *args, **options):
if options.get('celery'):
celery_error.delay(capture_and_log=options.get('log', False))
print(
'A RuntimeError exception was raised from a celery task. '
'Check the logs!'
)
if options.get('atomic'):
print('Inside an atomic block...')
context_manager = atomic
else:
print('About to raise an exception in management command')
try:
raise RuntimeError('This is an exception from a management command')
except Exception as exception:
if options.get('log', False):
log.exception('Capturing exception as a log', exc_info=exception)
else:
raise exception
context_manager = contextlib.nullcontext
with context_manager():
if options.get('celery'):
celery_error.delay(capture_and_log=options.get('log', False))
print(
'A RuntimeError exception was raised from a celery task. '
'Check the logs!'
)
else:
print('About to raise an exception in management command')
try:
raise RuntimeError('This is an exception from a management command')
except Exception as exception:
if options.get('log', False):
log.exception(
'Capturing exception as a log', exc_info=exception
)
else:
raise exception