diff --git a/example/contriblike/compile.py b/example/contriblike/compile.py
new file mode 100644
index 00000000..5c171419
--- /dev/null
+++ b/example/contriblike/compile.py
@@ -0,0 +1,112 @@
+import ast
+import sys
+from pathlib import Path
+from typing import Any, Sequence, Set
+
+import black
+
+
+def get_tree(path: Path):
+    src = path.read_text()
+    return ast.parse(src)
+
+
+def write_tree(tree, path):
+    new_src = ast.unparse(ast.fix_missing_locations(tree))
+    new_src = black.format_str(new_src, mode=black.FileMode(line_length=120))
+    path.write_text(new_src)
+
+
+# turn op function defs into async function defs
+class OpTransformer(ast.NodeTransformer):
+    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AsyncFunctionDef:
+        """any function in ops is an operator. make them async"""
+        self.generic_visit(node)
+        return ast.AsyncFunctionDef(
+            name=node.name,
+            args=node.args,
+            body=node.body,
+            decorator_list=node.decorator_list,
+            returns=node.returns,
+            type_comment=node.type_comment,
+        )
+
+
+class ImportTransformer(ast.NodeTransformer):
+    def __init__(self, *, module_allow_list: Sequence[str] = tuple()):
+        self.known_ops = set()
+        self.module_allow_list = frozenset(
+            {mn for mn in sys.stdlib_module_names if not mn.startswith("_")} | set(module_allow_list)
+        )
+
+    def visit_Import(self, node: ast.Import) -> ast.Import:
+        for alias in node.names:
+            if alias.name.split(".")[0] not in self.module_allow_list:
+                raise ValueError(f"Invalid 'import {alias.name}'.")
+
+        return node
+
+    def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:
+        if node.module == "ops":
+            for alias_node in node.names:
+                op_name = alias_node.name
+                self.known_ops.add(op_name)
+                if alias_node.asname is not None:
+                    raise ValueError(
+                        f"Please import operator names without 'as', i.e. use '{op_name}' instead of '{alias_node.asname}'."
+                    )
+            node.module = "compiled_ops"
+        elif node.module is None:
+            # relative import ('from . import x'): module is None; fail with a clear
+            # validation error instead of AttributeError on None.split below
+            raise ValueError("Relative imports are not supported.")
+        elif node.module.split(".")[0] in self.module_allow_list:
+            pass
+        else:
+            raise ValueError(f"Unsupported import from {node.module}")
+
+        return node
+
+
+class OpCallTransformer(ast.NodeTransformer):
+    def __init__(self, known_ops: Set[str]):
+        self.known_ops = known_ops
+
+    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AsyncFunctionDef:
+        """any function in wfs is a workflow. make them async"""
+        self.generic_visit(node)
+        return ast.AsyncFunctionDef(
+            name=node.name,
+            args=node.args,
+            body=node.body,
+            decorator_list=node.decorator_list,
+            returns=node.returns,
+            type_comment=node.type_comment,
+        )
+
+    def visit_Call(self, node: ast.Call) -> Any:
+        """await any operator call"""
+        self.generic_visit(node)
+        if isinstance(node.func, ast.Name):
+            if node.func.id in self.known_ops:
+                return ast.Await(node)
+            else:
+                return node
+        elif isinstance(node.func, ast.Attribute):
+            return node  # e.g. method call
+        else:
+            raise NotImplementedError(node.func)
+
+
+ops_path = Path("ops.py")
+ops_tree = get_tree(ops_path)
+ops_tree = OpTransformer().visit(ops_tree)
+write_tree(ops_tree, ops_path.with_name("compiled_" + ops_path.name))
+
+wfs_path = Path("wfs.py")
+wfs_tree = get_tree(wfs_path)
+
+allowed_module_names = []  # todo: add modules from appropriate env to allow-list
+import_transformer = ImportTransformer(module_allow_list=allowed_module_names)
+wfs_tree = import_transformer.visit(wfs_tree)
+wfs_tree = OpCallTransformer(known_ops=import_transformer.known_ops).visit(wfs_tree)
+
+write_tree(wfs_tree, wfs_path.with_name("compiled_" + wfs_path.name))
diff --git a/example/contriblike/compiled_ops.py b/example/contriblike/compiled_ops.py
new file mode 100644
index 00000000..ef2304f9
--- /dev/null
+++ b/example/contriblike/compiled_ops.py
@@ -0,0 +1,25 @@
+import shutil
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import Tuple
+
+
+async def my_op(a: int) -> str:
+    return f"{a:~^10}"
+
+
+async def heavy_compute(p):
+    Path(p).write_text("done")
+
+
+async def parallel_op(*, max_thread_workers: int) -> Tuple[str, str]:
+    srcs = ("src1.txt", "src2.txt")
+    dests = ("dest1.txt", "dest2.txt")
+    with ThreadPoolExecutor(max_workers=max_thread_workers) as e:
+        for (src, dest) in zip(srcs, dests):
+            e.submit(shutil.copy, src, dest)
+    return dests
+
+
+if __name__ == "__main__":
+    print(my_op(5))
diff --git a/example/contriblike/compiled_wfs.py b/example/contriblike/compiled_wfs.py
new file mode 100644
index 00000000..a1cec812
--- /dev/null
+++ b/example/contriblike/compiled_wfs.py
@@ -0,0 +1,20 @@
+from concurrent.futures import ProcessPoolExecutor
+from typing import Dict, Tuple
+from compiled_ops import heavy_compute, my_op, parallel_op
+
+
+async def my_wf(a: int) -> Dict[str, str]:
+    m = await my_op(a)
+    return {"centered": m}
+
+
+async def wf_with_p_ops(*, max_thread_workers: int, max_process_workers: int) -> Tuple[str, str]:
+    with ProcessPoolExecutor(max_workers=max_process_workers) as e:
+        e.submit(heavy_compute, "src1.txt")
+        e.submit(heavy_compute, "src2.txt")
+    return await parallel_op(max_thread_workers=max_thread_workers)
+
+
+if __name__ == "__main__":
+    print(my_wf(5))
+    print(wf_with_p_ops(max_thread_workers=2, max_process_workers=2))
diff --git a/example/contriblike/ops.py b/example/contriblike/ops.py
new file mode 100644
index 00000000..2b606cf9
--- /dev/null
+++ b/example/contriblike/ops.py
@@ -0,0 +1,26 @@
+import shutil
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import Tuple
+
+
+def my_op(a: int) -> str:
+    return f"{a:~^10}"
+
+
+async def heavy_compute(p):
+    Path(p).write_text("done")
+
+
+def parallel_op(*, max_thread_workers: int) -> Tuple[str, str]:
+    srcs = ("src1.txt", "src2.txt")
+    dests = ("dest1.txt", "dest2.txt")
+    with ThreadPoolExecutor(max_workers=max_thread_workers) as e:
+        for src, dest in zip(srcs, dests):
+            e.submit(shutil.copy, src, dest)
+
+    return dests
+
+
+if __name__ == "__main__":
+    print(my_op(5))
diff --git a/example/contriblike/wfs.py b/example/contriblike/wfs.py
new file mode 100644
index 00000000..1b56fe3a
--- /dev/null
+++ b/example/contriblike/wfs.py
@@ -0,0 +1,22 @@
+from concurrent.futures import ProcessPoolExecutor
+from typing import Dict, Tuple
+
+from ops import heavy_compute, my_op, parallel_op
+
+
+def my_wf(a: int) -> Dict[str, str]:
+    m = my_op(a)
+    return {"centered": m}
+
+
+def wf_with_p_ops(*, max_thread_workers: int, max_process_workers: int) -> Tuple[str, str]:
+    with ProcessPoolExecutor(max_workers=max_process_workers) as e:
+        e.submit(heavy_compute, "src1.txt")
+        e.submit(heavy_compute, "src2.txt")
+
+    return parallel_op(max_thread_workers=max_thread_workers)
+
+
+if __name__ == "__main__":
+    print(my_wf(5))
+    print(wf_with_p_ops(max_thread_workers=2, max_process_workers=2))