Skip to content

Commit

Permalink
add Tests (#3)
Browse files Browse the repository at this point in the history
* Add CI

* fix

* fixes

* constraint python

* Fix mypy

* add some basic tests

* pre-commit

* use parametrized fixtures

* fix if / else var assignment lookup

* add support for ast.IfExp

* raise ValueError on chained comparison

* delete test.py

* enable nested_if_else_expr

* add readme

* add two if statement test

* use pytest as pixi test command

* move debug logging to test

* Use pytest.raises

Co-authored-by: Pavel Zwerschke <[email protected]>

* spelling

Co-authored-by: Pavel Zwerschke <[email protected]>

* Moved to #5

* reset readme

* add warlus_expr xfail test

* add signum_no_default test

* use NodeTransformer to Inline Expr

* split decorator

---------

Co-authored-by: Pavel Zwerschke <[email protected]>
  • Loading branch information
Bela Stoyan and pavelzw authored Jul 30, 2023
1 parent 36919ed commit 1b9ce8e
Show file tree
Hide file tree
Showing 6 changed files with 338 additions and 172 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
pixi.lock

__pycache__

.hypothesis/
3 changes: 2 additions & 1 deletion pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ platforms = ["linux-64", "osx-arm64", "osx-64", "win-64"]

[tasks]
"postinstall" = "pip install --no-build-isolation --no-deps --disable-pip-version-check -e ."
"test" = "python polarify/test.py"
"test" = "pytest"
"lint" = "pre-commit run --all"

[dependencies]
Expand All @@ -23,5 +23,6 @@ python = ">= 3.9"
"pytest" = "*"
"pytest-md" = "*"
"pytest-emoji" = "*"
"hypothesis" = "*"
# linting
"pre-commit" = "*"
13 changes: 8 additions & 5 deletions polarify/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import inspect
from functools import wraps

from main import parse_body
from .main import parse_body


def polarify(func):
def transform_func_to_new_source(func) -> str:
source = inspect.getsource(func)
print(source)
tree = ast.parse(source)
func_def: ast.FunctionDef = tree.body[0] # type: ignore
expr = parse_body(func_def.body)
Expand All @@ -16,16 +15,20 @@ def polarify(func):
func_def.body = [ast.Return(expr)]
# TODO: make this prettier
func_def.decorator_list = []
func_def.name += "_polarified"

# Unparse the modified AST back into source code
new_func_code = ast.unparse(tree)
return ast.unparse(tree)


def polarify(func):
new_func_code = transform_func_to_new_source(func)
# Execute the new function code in the original function's globals
exec_globals = func.__globals__
exec(new_func_code, exec_globals)

# Get the new function from the globals
new_func = exec_globals[func.__name__]
new_func = exec_globals[func.__name__ + "_polarified"]

@wraps(func)
def wrapper(*args, **kwargs):
Expand Down
145 changes: 97 additions & 48 deletions polarify/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,80 @@
# TODO: make walrus throw ValueError
# TODO: Switch

assignments = dict[str, ast.expr]
Assignments = dict[str, ast.expr]


def inline_all(expr: ast.expr, assignments: assignments) -> ast.expr:
assignments = copy(assignments)
if isinstance(expr, ast.Name):
if expr.id not in assignments:
raise ValueError(f"Variable {expr.id} not defined")
return inline_all(assignments[expr.id], assignments)
elif isinstance(expr, ast.BinOp):
expr.left = inline_all(expr.left, assignments)
expr.right = inline_all(expr.right, assignments)
return expr
else:
return expr
def build_polars_when_then_otherwise(test: ast.expr, then: ast.expr, orelse: ast.expr):
when_node = ast.Call(
func=ast.Attribute(
value=ast.Name(id="pl", ctx=ast.Load()), attr="when", ctx=ast.Load()
),
args=[test],
keywords=[],
)

then_node = ast.Call(
func=ast.Attribute(value=when_node, attr="then", ctx=ast.Load()),
args=[then],
keywords=[],
)
final_node = ast.Call(
func=ast.Attribute(value=then_node, attr="otherwise", ctx=ast.Load()),
args=[orelse],
keywords=[],
)
return final_node


# ruff: noqa: N802
class InlineTransformer(ast.NodeTransformer):
def __init__(self, assignments: Assignments):
self.assignments = assignments

@classmethod
def inline_expr(cls, expr: ast.expr, assignments: Assignments) -> ast.expr:
return cls(assignments).visit(expr)

def visit_Name(self, node):
if node.id in self.assignments:
return self.visit(self.assignments[node.id])
else:
return node

def visit_BinOp(self, node):
node.left = self.visit(node.left)
node.right = self.visit(node.right)
return node

def visit_UnaryOp(self, node):
node.operand = self.visit(node.operand)
return node

def visit_Call(self, node):
node.args = [self.visit(arg) for arg in node.args]
node.keywords = [
ast.keyword(arg=k.arg, value=self.visit(k.value)) for k in node.keywords
]
return node

def visit_IfExp(self, node):
test = self.visit(node.test)
body = self.visit(node.body)
orelse = self.visit(node.orelse)
return build_polars_when_then_otherwise(test, body, orelse)

def visit_Constant(self, node):
return node

def visit_Compare(self, node):
if len(node.comparators) > 1:
raise ValueError("Polars can't handle chained comparisons")
node.left = self.visit(node.left)
node.comparators = [self.visit(c) for c in node.comparators]
return node

def generic_visit(self, node):
raise ValueError(f"Unsupported expression type: {type(node)}")


def is_returning_body(stmts: list[ast.stmt]) -> bool:
Expand All @@ -37,13 +96,13 @@ def is_returning_body(stmts: list[ast.stmt]) -> bool:
return False


def handle_assign(stmt: ast.Assign, assignments: assignments) -> assignments:
def handle_assign(stmt: ast.Assign, assignments: Assignments) -> Assignments:
assignments = copy(assignments)
diff_assignments = {}

for t in stmt.targets:
if isinstance(t, ast.Name):
new_value = inline_all(stmt.value, assignments)
new_value = InlineTransformer.inline_expr(stmt.value, assignments)
assignments[t.id] = new_value
diff_assignments[t.id] = new_value
elif isinstance(t, (ast.List, ast.Tuple)):
Expand All @@ -65,28 +124,40 @@ def handle_assign(stmt: ast.Assign, assignments: assignments) -> assignments:
return diff_assignments


def handle_non_returning_if(stmt: ast.If, assignments: assignments) -> assignments:
def handle_non_returning_if(stmt: ast.If, assignments: Assignments) -> Assignments:
assignments = copy(assignments)
assert not is_returning_body(stmt.orelse) and not is_returning_body(stmt.body)
test = inline_all(stmt.test, assignments)
test = InlineTransformer.inline_expr(stmt.test, assignments)

diff_assignments = {}
all_vars_changed_in_body = get_all_vars_changed_in_body(stmt.body, assignments)
all_vars_changed_in_orelse = get_all_vars_changed_in_body(stmt.orelse, assignments)

def updated_or_default_assignments(var: str, diff: Assignments) -> ast.expr:
if var in diff:
return diff[var]
elif var in assignments:
return assignments[var]
else:
raise ValueError(
f"Variable {var} has to be either defined in"
" all branches or have a previous defintion"
)

for var in all_vars_changed_in_body | all_vars_changed_in_orelse:
expr = build_polars_when_then_otherwise(
test,
all_vars_changed_in_body.get(var, assignments[var]),
all_vars_changed_in_orelse.get(var, assignments[var]),
updated_or_default_assignments(var, all_vars_changed_in_body),
updated_or_default_assignments(var, all_vars_changed_in_orelse),
)
assignments[var] = expr
diff_assignments[var] = expr
return diff_assignments


def get_all_vars_changed_in_body(
body: list[ast.stmt], assignments: assignments
) -> assignments:
body: list[ast.stmt], assignments: Assignments
) -> Assignments:
assignments = copy(assignments)
diff_assignments = {}

Expand All @@ -107,30 +178,8 @@ def get_all_vars_changed_in_body(
return diff_assignments


def build_polars_when_then_otherwise(test: ast.expr, then: ast.expr, orelse: ast.expr):
when_node = ast.Call(
func=ast.Attribute(
value=ast.Name(id="pl", ctx=ast.Load()), attr="when", ctx=ast.Load()
),
args=[test],
keywords=[],
)

then_node = ast.Call(
func=ast.Attribute(value=when_node, attr="then", ctx=ast.Load()),
args=[then],
keywords=[],
)
final_node = ast.Call(
func=ast.Attribute(value=then_node, attr="otherwise", ctx=ast.Load()),
args=[orelse],
keywords=[],
)
return final_node


def parse_body(
full_body: list[ast.stmt], assignments: Union[assignments, None] = None
full_body: list[ast.stmt], assignments: Union[Assignments, None] = None
) -> ast.expr:
if assignments is None:
assignments = {}
Expand All @@ -142,12 +191,12 @@ def parse_body(
assignments.update(handle_assign(stmt, assignments))
elif isinstance(stmt, ast.If):
if is_returning_body(stmt.body) and is_returning_body(stmt.orelse):
test = inline_all(stmt.test, assignments)
test = InlineTransformer.inline_expr(stmt.test, assignments)
body = parse_body(stmt.body, assignments)
orelse = parse_body(stmt.orelse, assignments)
return build_polars_when_then_otherwise(test, body, orelse)
elif is_returning_body(stmt.body):
test = inline_all(stmt.test, assignments)
test = InlineTransformer.inline_expr(stmt.test, assignments)
body = parse_body(stmt.body, assignments)
orelse_everything = parse_body(
stmt.orelse + full_body[i + 1 :], assignments
Expand All @@ -156,7 +205,7 @@ def parse_body(
elif is_returning_body(stmt.orelse):
test = ast.Call(
func=ast.Attribute(
value=inline_all(stmt.test, assignments),
value=InlineTransformer.inline_expr(stmt.test, assignments),
attr="not",
ctx=ast.Load(),
),
Expand All @@ -176,7 +225,7 @@ def parse_body(
if stmt.value is None:
raise ValueError("return needs a value")
# Handle return statements
return inline_all(stmt.value, assignments)
return InlineTransformer.inline_expr(stmt.value, assignments)
else:
raise ValueError(f"Unsupported statement type: {type(stmt)}")
raise ValueError("Missing return statement")
116 changes: 0 additions & 116 deletions polarify/test.py

This file was deleted.

Loading

0 comments on commit 1b9ce8e

Please sign in to comment.