Skip to content

Commit

Permalink
add callable for Task arg. commands #101
Browse files Browse the repository at this point in the history
  • Loading branch information
leo-schick committed Nov 22, 2023
1 parent 4837cdf commit 7260baf
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 11 deletions.
47 changes: 36 additions & 11 deletions mara_pipelines/pipelines.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import pathlib
import re
from typing import Optional, Dict, Set, List, Tuple, Union
from typing import Optional, Dict, Set, List, Tuple, Union, Callable

from . import config

Expand Down Expand Up @@ -81,24 +81,49 @@ def html_doc_items(self) -> List[Tuple[str, str]]:


class Task(Node):
def __init__(self, id: str, description: str, commands: Optional[List[Command]] = None, max_retries: Optional[int] = None) -> None:
def __init__(self, id: str, description: str, commands: Optional[Union[Callable, List[Command]]] = None, max_retries: Optional[int] = None) -> None:
super().__init__(id, description)
self.commands = []
self.callable_commands = callable(commands)
self.max_retries = max_retries

for command in commands or []:
self.add_command(command)

def add_command(self, command: Command, prepend=False):
if self.callable_commands:
self._commands = None
self.__commands_callable = commands
else:
self._commands = []
self._add_commands(commands or [])

def _test_is_not_dynamic(self):
if self.callable_commands:
raise Exception('You cannot use add_command when the task is constructed with a callable commands function.')

@property
def commands(self) -> List:
if not self._commands:
self._commands = []
# execute the callable command function and cache the result
for command in self.__commands_callable() or []:
self._add_command(command)
return self._commands

def _add_command(self, command: Command, prepend=False):
if prepend:
self.commands.insert(0, command)
self._commands.insert(0, command)
else:
self.commands.append(command)
self._commands.append(command)
command.parent = self

def add_commands(self, commands: List[Command]):
def _add_commands(self, commands: List[Command]):
for command in commands:
self.add_command(command)
self._add_command(command)

def add_command(self, command: Command, prepend=False):
self._test_is_not_dynamic()
self._add_command(command, prepend=prepend)

def add_commands(self, commands: List[Command]):
self._test_is_not_dynamic()
self._add_commands(commands)

def run(self):
for command in self.commands:
Expand Down
4 changes: 4 additions & 0 deletions mara_pipelines/ui/node_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def __(pipeline: pipelines.Pipeline):
def __(task: pipelines.Task):
if not acl.current_user_has_permission(views.acl_resource):
return bootstrap.card(header_left='Commands', body=acl.inline_permission_denied_message())
elif task.callable_commands:
return bootstrap.card(
header_left='Commands',
body=f"""<span style="font-style:italic;color:#aaa"><span class="fa fa-lightbulb"> </span> {"... are defined dynamically during execution"}</span>""")
else:
commands_card = bootstrap.card(
header_left='Commands',
Expand Down
64 changes: 64 additions & 0 deletions tests/test_pipeline_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import pytest
from typing import List

from mara_pipelines.pipelines import Task
from mara_pipelines.commands.python import RunFunction

class _PythonFuncTestResult:
has_run = False


def test_run_task():
"""
A simple test executing a task.
"""
test_result = _PythonFuncTestResult()

def python_test_function(result: _PythonFuncTestResult):
result.has_run = True # noqa: F841

assert not test_result.has_run

task = Task(
id='run_task',
description="Unit test test_run_task",
commands=[RunFunction(function=python_test_function, args=[test_result])])

assert not test_result.has_run

task.run()

assert test_result.has_run


def test_run_task_callable_commands():
"""
A simple test executing a task with callable commands
"""
import mara_pipelines.ui.node_page

test_result = _PythonFuncTestResult()

def python_test_function(result: _PythonFuncTestResult):
result.has_run = True # noqa: F841

def generate_command_list() -> List:
yield RunFunction(function=lambda t: python_test_function(t), args=[test_result])

assert not test_result.has_run

task = Task(
id='run_task_callable_commands',
description="Unit test test_run_task_callable_commands",
commands=generate_command_list)

assert not test_result.has_run

content = mara_pipelines.ui.node_page.node_content(task)
assert content

assert not test_result.has_run

task.run()

assert test_result.has_run

0 comments on commit 7260baf

Please sign in to comment.