Skip to content

Commit

Permalink
uow: update TaskOp; add TaskRevokeOp (#579)
Browse files Browse the repository at this point in the history
* uow: update TaskOp; add TaskRevokeOp

* tests: update tests
  • Loading branch information
yashlamba authored Jun 4, 2024
1 parent 939528b commit cf4460c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
25 changes: 24 additions & 1 deletion invenio_records_resources/services/uow.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def on_commit(self, uow):

from functools import wraps

from celery import current_app
from invenio_db import db

from ..tasks import send_change_notifications
Expand Down Expand Up @@ -261,10 +262,32 @@ def __init__(self, celery_task, *args, **kwargs):
self._celery_task = celery_task
self._args = args
self._kwargs = kwargs
self.celery_kwargs = {}

def on_post_commit(self, uow):
"""Run the post task operation."""
self._celery_task.delay(*self._args, **self._kwargs)
self._celery_task.apply_async(
args=self._args, kwargs=self._kwargs, **self.celery_kwargs
)

@classmethod
def for_async_apply(cls, celery_task, args=None, kwargs=None, **celery_kwargs):
"""Create TaskOp that supports apply_async args."""
temp = cls(celery_task, *(args or tuple()), **(kwargs or {}))
temp.celery_kwargs = celery_kwargs
return temp


class TaskRevokeOp(Operation):
"""A celery task stopping operation."""

def __init__(self, task_id: str) -> None:
"""Initialize the task operation."""
self.task_id = task_id

def on_post_commit(self, uow) -> None:
"""Run the revoke post commit."""
current_app.control.revoke(self.task_id, terminate=True)


class ChangeNotificationOp(Operation):
Expand Down
6 changes: 3 additions & 3 deletions tests/services/files/test_files_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ def test_image_meta_extraction(
file_service.set_file_content(identity_simple, recid, "image.png", image_fp)

# Commit (should send celery task)
assert not task.delay.called
assert not task.apply_async.called
file_service.commit_file(identity_simple, recid, "image.png")
assert task.delay.called
assert task.apply_async.called

# Call task manually
extract_file_metadata(*task.delay.call_args[0])
extract_file_metadata(*task.apply_async.call_args[1]["args"])

item = file_service.read_file_metadata(identity_simple, recid, "image.png")
assert item.data["metadata"] == {"width": 1000, "height": 1000}

0 comments on commit cf4460c

Please sign in to comment.