Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for compiling functions inline #69

Merged
merged 12 commits into from
Jan 7, 2025
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ print(Mem.cell1[63])

> **Note**: if you are using function calls, *pyndustric* will use `cell1` as a call stack, so you might not want to use `cell1` in that case to store data in it. (If you aren't calling any functions in your code, using `cell1` should be fine.)

> Alternatively, you can mark your functions with the `@inline` decorator, and it will compile them *inline*, so the function code gets copied to each function call. This is faster and means you don't need a memory cell, but if the function is used more than once, it will quickly bloat the generated code size.

Custom function definitions:

```python
Expand Down
7 changes: 6 additions & 1 deletion pyndustri.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC
from typing import Iterator
from typing import Iterator, TypeVar, Callable

class Link:
"""Represents a link."""
Expand Down Expand Up @@ -708,6 +708,11 @@ def sleep(secs: float):
Sleep for the given amount of seconds.
"""

T = TypeVar("T")

def inline(func: Callable[..., T]) -> Callable[..., T]:
"""Compile the function by copy/pasting the code into each function call"""

def flip(a: int) -> int:
"""Bitwise complement"""

Expand Down
100 changes: 62 additions & 38 deletions pyndustric/compiler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from re import sub

from .constants import *
from dataclasses import dataclass
from pathlib import Path
Expand Down Expand Up @@ -140,6 +142,8 @@ def __init__(self):
self._in_def = None # current function name
self._epilogue = None # current function's epilogue label
self._functions = {}
self._inline_functions = {}
self._in_inline_function = False
self._tmp_var_counter = 0
self._scope_start_label = (
[]
Expand Down Expand Up @@ -438,59 +442,72 @@ def visit_FunctionDef(self, node):
raise CompilerError(ERR_REDEF, node, a=node.name)

self._in_def = node.name
reg_ret = f"{REG_RET_COUNTER_PREFIX}{len(self._functions)}"

args = node.args
if any(
(
args.vararg,
args.kwonlyargs,
args.kw_defaults,
args.kwarg,
args.defaults,
)
):

decorators = [i.id for i in node.decorator_list]
# Check that all the decorators are valid
if any(decorator not in ALLOWED_DECORATORS for decorator in decorators):
# TODO: Add description specifiying that the decorator is the problem
raise CompilerError(ERR_INVALID_DEF, node, a=node.name)

if sys.version_info >= (3, 8):
if args.posonlyargs:
if "inline" in decorators:
self._inline_functions[node.name] = node.body
else:
reg_ret = f"{REG_RET_COUNTER_PREFIX}{len(self._functions)}"

args = node.args
if any(
(
args.vararg,
args.kwonlyargs,
args.kw_defaults,
args.kwarg,
args.defaults,
)
):
raise CompilerError(ERR_INVALID_DEF, node, a=node.name)

# TODO it's better to put functions at the end and not have to skip them as code, but jumps need fixing
end = _Label()
self.ins_append(_Jump(end, "always"))
if sys.version_info >= (3, 8):
if args.posonlyargs:
raise CompilerError(ERR_INVALID_DEF, node, a=node.name)

prologue = _Label()
self.ins_append(prologue)
self._functions[node.name] = Function(start=prologue, argc=len(args.args))
# TODO it's better to put functions at the end and not have to skip them as code, but jumps need fixing
end = _Label()
self.ins_append(_Jump(end, "always"))

self.ins_append(f"read {reg_ret} cell1 {REG_STACK}")
for arg in reversed(args.args):
self.ins_append(f"op sub {REG_STACK} {REG_STACK} 1")
self.ins_append(f"read {arg.arg} cell1 {REG_STACK}")
prologue = _Label()
self.ins_append(prologue)
self._functions[node.name] = Function(start=prologue, argc=len(args.args))

# This relies on the fact that there are no nested definitions.
# Set the epilogue now so that `visit_Return` can use this label.
self._epilogue = _Label()
self.ins_append(f"read {reg_ret} cell1 {REG_STACK}")
for arg in reversed(args.args):
self.ins_append(f"op sub {REG_STACK} {REG_STACK} 1")
self.ins_append(f"read {arg.arg} cell1 {REG_STACK}")

for subnode in node.body:
self.visit(subnode)
# This relies on the fact that there are no nested definitions.
# Set the epilogue now so that `visit_Return` can use this label.
self._epilogue = _Label()

self.ins_append(self._epilogue)
for subnode in node.body:
self.visit(subnode)

self.ins_append(self._epilogue)

# Add 1 to the return value to skip the jump that made the call.
self.ins_append(f"op add @counter {reg_ret} 1")
self.ins_append(end)
self._in_def = None
self._epilogue = None
# Add 1 to the return value to skip the jump that made the call.
self.ins_append(f"op add @counter {reg_ret} 1")
self.ins_append(end)
self._in_def = None
self._epilogue = None

def visit_Return(self, node):
if not self._epilogue:
if not self._epilogue and not self._in_inline_function:
raise CompilerError(INTERNAL_COMPILER_ERR, node, "return encountered with epilogue being unset")

val = self.as_value(node.value)
self.ins_append(f"set {REG_RET} {val}")
self.ins_append(_Jump(self._epilogue, "always"))
if self._in_inline_function:
self.ins_append(f"set {REG_RET} {val}")
else:
self.ins_append(f"set {REG_RET} {val}")
self.ins_append(_Jump(self._epilogue, "always"))

def visit_Expr(self, node):
call = node.value
Expand Down Expand Up @@ -1302,6 +1319,13 @@ def as_value(self, node, output: str = None):
self.ins_append(f"op {function} {output} {operands}")
return output

elif node.func.id in self._inline_functions:
body = self._inline_functions[node.func.id]
self._in_inline_function = True
for subnode in body:
self.visit(subnode)
self._in_inline_function = False
return REG_RET
else:
fn = self._functions.get(node.func.id)
if fn is None:
Expand Down
3 changes: 3 additions & 0 deletions pyndustric/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ERR_BAD_TUPLE_ASSIGN = "BadTupleError"
INTERNAL_COMPILER_ERR = "InternalCompilerError"


ERROR_DESCRIPTIONS = {
ERR_COMPLEX_ASSIGN: "cannot perform complex assignment `{unparsed}`",
ERR_COMPLEX_VALUE: "cannot evaluate complex value `{unparsed}`",
Expand Down Expand Up @@ -157,6 +158,8 @@
"payload_type": "@payloadType",
}

ALLOWED_DECORATORS = ("inline",)

REG_STACK = "__pyc_sp"
REG_RET = "__pyc_ret"
REG_RET_COUNTER_PREFIX = "__pyc_rc_"
Expand Down
82 changes: 82 additions & 0 deletions test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,88 @@ def source2():
assert masm == expected


def test_inline():
def source():
def f():
x = 1

f()

def source_inline():
@inline
def f():
x = 1

f()

expected = as_masm(
"""\
jump 5 always
read __pyc_rc_0 cell1 __pyc_sp
set x 1
op add @counter __pyc_rc_0 1
write @counter cell1 __pyc_sp
jump 2 always
set __pyc_tmp_1 __pyc_ret
"""
)
expected_inline = as_masm(
"""\
set x 1
"""
)

masm = pyndustric.Compiler().compile(source)
assert masm == expected

masm = pyndustric.Compiler().compile(source_inline)
assert masm == expected_inline


def test_inline_return():
def source():
def f():
x = 1
return x

rtn = f()

def source_inline():
@inline
def f():
x = 1
return x

rtn = f()

expected = as_masm(
"""\
jump 7 always
read __pyc_rc_0 cell1 __pyc_sp
set x 1
set __pyc_ret x
jump 6 always
op add @counter __pyc_rc_0 1
write @counter cell1 __pyc_sp
jump 2 always
set rtn __pyc_ret
"""
)
expected_inline = as_masm(
"""\
set x 1
set __pyc_ret x
set rtn __pyc_ret
"""
)

masm = pyndustric.Compiler().compile(source)
assert masm == expected

masm = pyndustric.Compiler().compile(source_inline)
assert masm == expected_inline


@masm_test
def test_assignments():
"""
Expand Down