diff --git a/ansible_rulebook/app.py b/ansible_rulebook/app.py index a9855330..45ae4b3b 100644 --- a/ansible_rulebook/app.py +++ b/ansible_rulebook/app.py @@ -130,6 +130,7 @@ async def run(parsed_args: argparse.Namespace) -> None: should_reload = await run_rulesets( event_log, + tasks, ruleset_queues, startup_args.variables, startup_args.inventory, diff --git a/ansible_rulebook/engine.py b/ansible_rulebook/engine.py index e4a54ee5..f3cefc73 100644 --- a/ansible_rulebook/engine.py +++ b/ansible_rulebook/engine.py @@ -18,7 +18,7 @@ import os import runpy from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from drools.dispatch import establish_async_channel, handle_async_messages from drools.ruleset import session_stats @@ -255,6 +255,7 @@ async def monitor_rulebook(rulebook_file): async def run_rulesets( event_log: asyncio.Queue, + source_tasks: Tuple[List[asyncio.Task]], ruleset_queues: List[RuleSetQueue], variables: Dict, inventory: str = "", @@ -299,6 +300,7 @@ async def run_rulesets( ruleset_runner = RuleSetRunner( event_log=event_log, ruleset_queue_plan=ruleset_queue_plan, + source_tasks=source_tasks, hosts_facts=hosts_facts, variables=variables, rule_set=rulesets[ruleset_queue_plan.ruleset.name], diff --git a/ansible_rulebook/rule_set_runner.py b/ansible_rulebook/rule_set_runner.py index 8d6c534c..5614daa6 100644 --- a/ansible_rulebook/rule_set_runner.py +++ b/ansible_rulebook/rule_set_runner.py @@ -18,7 +18,7 @@ import uuid from pprint import PrettyPrinter, pformat from types import MappingProxyType -from typing import Dict, List, Optional, Union, cast +from typing import Dict, List, Optional, Tuple, Union, cast import dpath from drools import ruleset as lang @@ -82,6 +82,7 @@ def __init__( self, event_log: asyncio.Queue, ruleset_queue_plan: EngineRuleSetQueuePlan, + source_tasks: Tuple[List[asyncio.Task]], hosts_facts, variables, rule_set, @@ -92,6 +93,7 @@ def __init__( self.action_loop_task = None self.event_log = event_log self.ruleset_queue_plan = ruleset_queue_plan + self.source_tasks = source_tasks self.name = ruleset_queue_plan.ruleset.name self.rule_set = rule_set self.hosts_facts = hosts_facts @@ -175,6 +177,8 @@ async def _handle_shutdown(self): self.name, str(self.shutdown), ) + for task in self.source_tasks: + task.cancel() if self.shutdown.kind == "now": logger.debug( "ruleset: %s has issued an immediate shutdown", self.name