diff --git a/django_dramatiq/apps.py b/django_dramatiq/apps.py index 39ed220..6ac41c5 100644 --- a/django_dramatiq/apps.py +++ b/django_dramatiq/apps.py @@ -36,13 +36,12 @@ class DjangoDramatiqConfig(AppConfig): name = "django_dramatiq" verbose_name = "Django Dramatiq" - @classmethod - def initialize(cls): + def ready(self): global RATE_LIMITER_BACKEND - dramatiq.set_encoder(cls.select_encoder()) + dramatiq.set_encoder(self.select_encoder()) - result_backend_settings = cls.result_backend_settings() + result_backend_settings = self.result_backend_settings() if result_backend_settings: result_backend_path = result_backend_settings.get("BACKEND", "dramatiq.results.backends.StubBackend") result_backend_class = import_string(result_backend_path) @@ -55,7 +54,7 @@ def initialize(cls): result_backend = None results_middleware = None - rate_limiter_backend_settings = cls.rate_limiter_backend_settings() + rate_limiter_backend_settings = self.rate_limiter_backend_settings() if rate_limiter_backend_settings: rate_limiter_backend_path = rate_limiter_backend_settings.get( "BACKEND", "dramatiq.rate_limits.backends.stub.StubBackend" @@ -64,11 +63,14 @@ def initialize(cls): rate_limiter_backend_options = rate_limiter_backend_settings.get("BACKEND_OPTIONS", {}) RATE_LIMITER_BACKEND = rate_limiter_backend_class(**rate_limiter_backend_options) - broker_settings = cls.broker_settings() + broker_settings = self.broker_settings() broker_path = broker_settings["BROKER"] broker_class = import_string(broker_path) broker_options = broker_settings.get("OPTIONS", {}) - middleware = [load_middleware(path) for path in broker_settings.get("MIDDLEWARE", [])] + middleware = [ + load_middleware(path, **self.get_middleware_kwargs(path)) + for path in broker_settings.get("MIDDLEWARE", []) + ] if result_backend is not None: middleware.append(results_middleware) @@ -84,6 +86,14 @@ def rate_limiter_backend(self): return RATE_LIMITER_BACKEND + def get_middleware_kwargs(self, path): + if isinstance(path, str): + middleware_path = path.rsplit('.', 1)[1].lower() + middleware_kwargs_method = "middleware_{}_kwargs".format(middleware_path) + if hasattr(self, middleware_kwargs_method): + return getattr(self, middleware_kwargs_method)() + return {} + @classmethod def broker_settings(cls): return getattr(settings, "DRAMATIQ_BROKER", DEFAULT_BROKER_SETTINGS) @@ -104,6 +114,3 @@ def tasks_database(cls): def select_encoder(cls): encoder = getattr(settings, "DRAMATIQ_ENCODER", DEFAULT_ENCODER) return import_string(encoder)() - - -DjangoDramatiqConfig.initialize() diff --git a/django_dramatiq/utils.py b/django_dramatiq/utils.py index 9024065..7da6653 100644 --- a/django_dramatiq/utils.py +++ b/django_dramatiq/utils.py @@ -1,7 +1,7 @@ from django.utils.module_loading import import_string -def load_middleware(path_or_obj): +def load_middleware(path_or_obj, **kwargs): if isinstance(path_or_obj, str): - return import_string(path_or_obj)() + return import_string(path_or_obj)(**kwargs) return path_or_obj