diff --git a/.travis.yml b/.travis.yml index d759f70..c250207 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,10 +13,11 @@ env: install: - "pip install \"Django${DJANGO_SPEC}\" \"celery${CELERY_SPEC}\"" - - pip install . + - pip install -r requirements.txt -e . services: - postgresql + - rabbitmq addons: postgresql: "9.5" @@ -27,4 +28,4 @@ before_script: - psql -c "create user tenant_celery with password 'qwe123'" -U postgres - psql -c "alter role tenant_celery createdb" -U postgres -script: "cd test_app && python manage.py test tenant_schemas_celery" +script: "./run-tests" diff --git a/requirements.in b/requirements.in new file mode 100644 index 0000000..bf52995 --- /dev/null +++ b/requirements.in @@ -0,0 +1,4 @@ +pip-tools>=3.1.0 +pytest>=3.8.2 +pytest-django>=3.4.3 + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..9aacd88 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,16 @@ +# +# This file is autogenerated by pip-compile +# To update, run: +# +# pip-compile --output-file requirements.txt requirements.in +# +atomicwrites==1.2.1 # via pytest +attrs==18.2.0 # via pytest +click==7.0 # via pip-tools +more-itertools==4.3.0 # via pytest +pip-tools==3.1.0 +pluggy==0.7.1 # via pytest +py==1.6.0 # via pytest +pytest-django==3.4.3 +pytest==3.8.2 +six==1.11.0 # via more-itertools, pip-tools, pytest diff --git a/run-tests b/run-tests new file mode 100755 index 0000000..474ef9b --- /dev/null +++ b/run-tests @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 +import subprocess +import os + + +def main(): + new_environ = os.environ.copy() + new_environ["DJANGO_SETTINGS_MODULE"] = "test_app.settings" + cwd = os.path.abspath(os.path.join(os.path.dirname(__file__), "test_app")) + + celery_proc = subprocess.Popen( + ["celery", "worker", "-A", "tenant_schemas_celery.test_app:app", "-l", "INFO"], + env=new_environ.copy(), + cwd=cwd, + ) + try: + subprocess.check_call( + ["pytest", "../tenant_schemas_celery/tests.py"], + env=new_environ.copy(), + cwd=cwd, + ) + + finally: + celery_proc.terminate() + celery_proc.wait() + + +if __name__ == "__main__": + main() diff --git a/tenant_schemas_celery/compat.py b/tenant_schemas_celery/compat.py index 6b7d2fa..4fadedb 100644 --- a/tenant_schemas_celery/compat.py +++ b/tenant_schemas_celery/compat.py @@ -2,6 +2,6 @@ from django_tenants.test.cases import TenantTestCase from django_tenants.utils import get_public_schema_name, get_tenant_model -except ImportError as e: +except ImportError: from tenant_schemas.test.cases import TenantTestCase from tenant_schemas.utils import get_public_schema_name, get_tenant_model diff --git a/tenant_schemas_celery/tasks.py b/tenant_schemas_celery/tasks.py new file mode 100644 index 0000000..24ccf1b --- /dev/null +++ b/tenant_schemas_celery/tasks.py @@ -0,0 +1,14 @@ +@app.task +def update_task(model_id, name): + dummy = DummyModel.objects.get(pk=model_id) + dummy.name = name + dummy.save() + +@app.task(bind=True) +def update_retry_task(self, model_id, name): + connection.close() + if update_retry_task.request.retries: + return update_task(model_id, name) + + # Don't throw the Retry exception. + self.retry(throw=False) diff --git a/tenant_schemas_celery/test_app.py b/tenant_schemas_celery/test_app.py new file mode 100644 index 0000000..4a96f1e --- /dev/null +++ b/tenant_schemas_celery/test_app.py @@ -0,0 +1,15 @@ +try: + from .app import CeleryApp +except ImportError: + app = None +else: + app = CeleryApp('testapp') + + class CeleryConfig: + BROKER_URL = 'amqp://' + CELERY_RESULT_BACKEND = 'rpc://' + CELERY_RESULT_PERSISTENT = False + CELERY_ALWAYS_EAGER = False + + app.config_from_object(CeleryConfig) + app.autodiscover_tasks(['tenant_schemas_celery'], 'test_tasks') diff --git a/tenant_schemas_celery/test_tasks.py b/tenant_schemas_celery/test_tasks.py new file mode 100644 index 0000000..b4a1f79 --- /dev/null +++ b/tenant_schemas_celery/test_tasks.py @@ -0,0 +1,30 @@ +from __future__ import absolute_import + +from test_app.tenant.models import DummyModel + +from .test_app import app + + +class DoesNotExist(Exception): + pass + + +@app.task +def update_task(model_id, name): + try: + dummy = DummyModel.objects.get(pk=model_id) + + except DummyModel.DoesNotExist: + raise DoesNotExist() + + dummy.name = name + dummy.save() + + +@app.task(bind=True) +def update_retry_task(self, model_id, name): + if update_retry_task.request.retries: + return update_task(model_id, name) + + # Don't throw the Retry exception. + self.retry(countdown=0.1) diff --git a/tenant_schemas_celery/tests.py b/tenant_schemas_celery/tests.py index 2587f28..2056d29 100644 --- a/tenant_schemas_celery/tests.py +++ b/tenant_schemas_celery/tests.py @@ -1,4 +1,7 @@ -from unittest import skipIf +from __future__ import absolute_import + +import pytest +import time from django.db import connection from django.db.models.fields import FieldDoesNotExist @@ -6,121 +9,118 @@ from test_app.shared.models import Client from test_app.tenant.models import DummyModel -from .compat import get_public_schema_name, TenantTestCase +from .compat import get_public_schema_name +from .test_tasks import update_task, update_retry_task, DoesNotExist -try: - from .app import CeleryApp -except ImportError: - app = None -else: - app = CeleryApp('testapp') - class CeleryConfig: - CELERY_ALWAYS_EAGER = True - CELERY_EAGER_PROPAGATES_EXCEPTIONS = True +@pytest.fixture +def setup_tenant_test(transactional_db): + kwargs1 = {} + kwargs2 = {} - app.config_from_object(CeleryConfig) + data = {} - @app.task - def update_task(model_id, name): - dummy = DummyModel.objects.get(pk=model_id) - dummy.name = name - dummy.save() + try: + Client._meta.get_field('domain_url') + except FieldDoesNotExist: + pass + else: + kwargs1 = {'domain_url': 'test1.test.com'} + kwargs2 = {'domain_url': 'test2.test.com'} - @app.task - def update_retry_task(model_id, name): - if update_retry_task.request.retries: - return update_task(model_id, name) + tenant1 = data['tenant1'] = Client(name='test1', schema_name='test1', **kwargs1) + tenant1.save() - # Don't throw the Retry exception. - update_retry_task.retry(throw=False) + tenant2 = data['tenant2'] = Client(name='test2', schema_name='test2', **kwargs2) + tenant2.save() + connection.set_tenant(tenant1) + DummyModel.objects.all().delete() + data['dummy1'] = DummyModel.objects.create(name='test1') -@skipIf(app is None, 'Celery is not available.') -class CeleryTasksTests(TenantTestCase): - @classmethod - def setUpClass(cls): - pass + connection.set_tenant(tenant2) + DummyModel.objects.all().delete() + data['dummy2'] = DummyModel.objects.create(name='test2') - @classmethod - def tearDownClass(cls): - pass + connection.set_schema_to_public() - def setUp(self): - kwargs1 = {} - kwargs2 = {} + try: + yield data - try: - Client._meta.get_field('domain_url') - except FieldDoesNotExist: - pass - else: - kwargs1 = {'domain_url': 'test1.test.com'} - kwargs2 = {'domain_url': 'test2.test.com'} + finally: + connection.set_schema_to_public() - self.tenant1 = Client(name='test1', schema_name='test1', **kwargs1) - self.tenant1.save() - self.tenant2 = Client(name='test2', schema_name='test2', **kwargs2) - self.tenant2.save() +def test_should_update_model(setup_tenant_test): + dummy1, dummy2 = setup_tenant_test['dummy1'], setup_tenant_test['dummy2'] - connection.set_tenant(self.tenant1) - self.dummy1 = DummyModel.objects.create(name='test1') + # We should be in public schema where dummies don't exist. + for dummy in dummy1, dummy2: + # Test both async and local versions. + with pytest.raises(DoesNotExist): + update_task.apply_async(args=(dummy.pk, 'updated-name')).get() - connection.set_tenant(self.tenant2) - self.dummy2 = DummyModel.objects.create(name='test2') + with pytest.raises(DoesNotExist): + update_task.apply(args=(dummy.pk, 'updated-name')).get() - connection.set_schema_to_public() + connection.set_tenant(setup_tenant_test['tenant1']) + update_task.apply_async(args=(dummy1.pk, 'updated-name')).get() + assert connection.schema_name == setup_tenant_test['tenant1'].schema_name - def tearDown(self): - connection.set_schema_to_public() + # The task restores the schema from before running the task, so we are + # using the `tenant1` tenant now. + model_count = DummyModel.objects.filter(name='updated-name').count() + assert model_count == 1 - def test_basic_model_update(self): - # We should be in public schema where dummies don't exist. - for dummy in self.dummy1, self.dummy2: - # Test both async and local versions. - with self.assertRaises(DummyModel.DoesNotExist): - update_task.apply_async(args=(dummy.pk, 'updated-name')) + connection.set_tenant(setup_tenant_test['tenant2']) + model_count = DummyModel.objects.filter(name='updated-name').count() + assert model_count == 0 - with self.assertRaises(DummyModel.DoesNotExist): - update_task.apply(args=(dummy.pk, 'updated-name')) - connection.set_tenant(self.tenant1) - update_task.apply_async(args=(self.dummy1.pk, 'updated-name')) - self.assertEqual(connection.schema_name, self.tenant1.schema_name) +def test_task_retry(setup_tenant_test): + dummy1 = setup_tenant_test['dummy1'] - # The task restores the schema from before running the task, so we are - # using the `tenant1` tenant now. - model_count = DummyModel.objects.filter(name='updated-name').count() - self.assertEqual(model_count, 1) + # Schema name should persist through retry attempts. + connection.set_tenant(setup_tenant_test['tenant1']) + update_retry_task.apply_async(args=(dummy1.pk, 'updated-name')).get() - connection.set_tenant(self.tenant2) + for _ in range(19): model_count = DummyModel.objects.filter(name='updated-name').count() - self.assertEqual(model_count, 0) + try: + assert model_count == 1 - def test_task_retry(self): - # Schema name should persist through retry attempts. - connection.set_tenant(self.tenant1) - update_retry_task.apply_async(args=(self.dummy1.pk, 'updated-name')) + except AssertionError: + # Wait for the retried task to finish. + time.sleep(0.1) - model_count = DummyModel.objects.filter(name='updated-name').count() - self.assertEqual(model_count, 1) + else: + break + + model_count = DummyModel.objects.filter(name='updated-name').count() + assert model_count == 1 + + +def test_restoring_schema_name(setup_tenant_test): + dummy1 = setup_tenant_test['dummy1'] + dummy2 = setup_tenant_test['dummy2'] + + with tenant_context(setup_tenant_test['tenant1']): + update_task.apply_async(args=(dummy1.pk, 'updated-name')).get() + + assert connection.schema_name == get_public_schema_name() + + connection.set_tenant(setup_tenant_test['tenant1']) - def test_restoring_schema_name(self): - with tenant_context(self.tenant1): - update_task.apply_async(args=(self.dummy1.pk, 'updated-name')) - self.assertEqual(connection.schema_name, get_public_schema_name()) + with tenant_context(setup_tenant_test['tenant2']): + update_task.apply_async(args=(dummy2.pk, 'updated-name')).get() - connection.set_tenant(self.tenant1) + assert connection.schema_name == setup_tenant_test['tenant1'].schema_name - with tenant_context(self.tenant2): - update_task.apply_async(args=(self.dummy2.pk, 'updated-name')) - self.assertEqual(connection.schema_name, self.tenant1.schema_name) + connection.set_tenant(setup_tenant_test['tenant2']) - connection.set_tenant(self.tenant2) - # The model does not exist in the public schema. - with self.assertRaises(DummyModel.DoesNotExist): - with schema_context(get_public_schema_name()): - update_task.apply_async(args=(self.dummy2.pk, 'updated-name')) + # The model does not exist in the public schema. + with pytest.raises(DoesNotExist): + with schema_context(get_public_schema_name()): + update_task.apply_async(args=(dummy2.pk, 'updated-name')).get() - self.assertEqual(connection.schema_name, self.tenant2.schema_name) + assert connection.schema_name == setup_tenant_test['tenant2'].schema_name diff --git a/test_app/test_app/settings.py b/test_app/test_app/settings.py index 39d7873..b5241c6 100644 --- a/test_app/test_app/settings.py +++ b/test_app/test_app/settings.py @@ -91,6 +91,9 @@ 'NAME': 'tenant_celery', 'PASSWORD': 'qwe123', 'USER': 'tenant_celery', + 'TEST': { + 'NAME': 'tenant_celery', + } } }