Skip to content

Commit

Permalink
Control flow (apache#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
heheda12345 authored Oct 18, 2023
2 parents d6f9b69 + a27558a commit cc4ef4d
Show file tree
Hide file tree
Showing 30 changed files with 1,351 additions and 387 deletions.
196 changes: 164 additions & 32 deletions frontend/bytecode_analysis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import dataclasses
import dis
import sys
import functools
from typing import Union, List
from collections import deque
from .instruction import Instruction

TERMINAL_OPCODES = {
Expand All @@ -17,6 +19,10 @@
TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"])
JUMP_OPCODES = set(dis.hasjrel + dis.hasjabs)
JUMP_OPNAMES = {dis.opname[opcode] for opcode in JUMP_OPCODES}
MUST_JUMP_OPCODES = {
dis.opmap["JUMP_FORWARD"],
dis.opmap["JUMP_ABSOLUTE"],
}
HASLOCAL = set(dis.haslocal)
HASFREE = set(dis.hasfree)

Expand All @@ -43,39 +49,80 @@ class ReadsWrites:
def livevars_analysis(instructions: List[Instruction],
instruction: Instruction) -> set[str]:
indexof = get_indexof(instructions)
must = ReadsWrites(set(), set(), set())
may = ReadsWrites(set(), set(), set())

def walk(state: ReadsWrites, start: int) -> None:
if start in state.visited:
return
state.visited.add(start)

for i in range(start, len(instructions)):
inst = instructions[i]
if inst.opcode in HASLOCAL or inst.opcode in HASFREE:
if "LOAD" in inst.opname or "DELETE" in inst.opname:
assert isinstance(inst.argval, str)
if inst.argval not in must.writes:
state.reads.add(inst.argval)
elif "STORE" in inst.opname:
assert isinstance(inst.argval, str)
state.writes.add(inst.argval)
elif inst.opname == "MAKE_CELL":
pass
else:
raise NotImplementedError(f"unhandled {inst.opname}")
# if inst.exn_tab_entry:
# walk(may, indexof[inst.exn_tab_entry.target])
if inst.opcode in JUMP_OPCODES:
assert inst.target is not None
walk(may, indexof[inst.target])
state = may
if inst.opcode in TERMINAL_OPCODES:
return

walk(must, indexof[instruction])
return must.reads | may.reads
prev: dict[int, list[int]] = {}
succ: dict[int, list[int]] = {}
prev[0] = []
for i, inst in enumerate(instructions):
if inst.opcode not in TERMINAL_OPCODES:
prev[i + 1] = [i]
succ[i] = [i + 1]
else:
prev[i + 1] = []
succ[i] = []
for i, inst in enumerate(instructions):
if inst.opcode in JUMP_OPCODES:
assert inst.target is not None
target_pc = indexof[inst.target]
prev[target_pc].append(i)
succ[i].append(target_pc)

live_vars: dict[int, frozenset[str]] = {}

start_pc = indexof[instruction]
to_visit = deque([
pc for pc in range(len(instructions))
if instructions[pc].opcode in TERMINAL_OPCODES
])
in_progress: set[int] = set(to_visit)

def join_fn(a: frozenset[str], b: frozenset[str]) -> frozenset[str]:
return frozenset(a | b)

def gen_fn(
inst: Instruction,
incoming: frozenset[str]) -> tuple[frozenset[str], frozenset[str]]:
gen = set()
kill = set()
if inst.opcode in HASLOCAL or inst.opcode in HASFREE:
if "LOAD" in inst.opname or "DELETE" in inst.opname:
assert isinstance(inst.argval, str)
gen.add(inst.argval)
elif "STORE" in inst.opname:
assert isinstance(inst.argval, str)
kill.add(inst.argval)
elif inst.opname == "MAKE_CELL":
pass
else:
raise NotImplementedError(f"unhandled {inst.opname}")

return frozenset(gen), frozenset(kill)

while len(to_visit) > 0:
pc = to_visit.popleft()
in_progress.remove(pc)
if pc in live_vars:
before = hash(live_vars[pc])
else:
before = None
succs = [
live_vars[succ_pc] for succ_pc in succ[pc] if succ_pc in live_vars
]
if len(succs) > 0:
incoming = functools.reduce(join_fn, succs)
else:
incoming = frozenset()

gen, kill = gen_fn(instructions[pc], incoming)

out = (incoming - kill) | gen
live_vars[pc] = out
if hash(out) != before:
for prev_pc in prev[pc]:
if prev_pc not in in_progress:
to_visit.append(prev_pc)
in_progress.add(prev_pc)
return set(live_vars[start_pc])


stack_effect = dis.stack_effect
Expand Down Expand Up @@ -145,3 +192,88 @@ def stacksize_analysis(instructions: List[Instruction]) -> int:
assert low >= 0
assert isinstance(high, int) # not infinity
return high


def end_of_control_flow(instructions: List[Instruction], start_pc: int) -> int:
"""
Find the end of the control flow block starting at the given instruction.
"""
while instructions[start_pc].opname == 'EXTENDED_ARG':
start_pc += 1
assert instructions[start_pc].opcode in JUMP_OPCODES
assert instructions[start_pc].target is not None
indexof = get_indexof(instructions)
jump_only_opnames = ['JUMP_FORWARD', 'JUMP_ABSOLUTE']
jump_or_next_opnames = [
'POP_JUMP_IF_TRUE', 'POP_JUMP_IF_FALSE', 'JUMP_IF_NOT_EXC_MATCH',
'JUMP_IF_TRUE_OR_POP', 'JUMP_IF_FALSE_OR_POP', 'FOR_ITER'
]
jump_only_opcodes = [dis.opmap[opname] for opname in jump_only_opnames]
jump_or_next_opcodes = [
dis.opmap[opname] for opname in jump_or_next_opnames
]
return_value_opcode = dis.opmap['RETURN_VALUE']
possible_end_pcs = set()
for end_pc, inst in enumerate(instructions):
if end_pc == start_pc:
continue
inst = instructions[end_pc]
if not inst.is_jump_target:
continue
visited = set()
queue = deque([start_pc])
reach_end = False
while queue and not reach_end:
pc = queue.popleft()
inst = instructions[pc]
targets: list[int] = []
if inst.target is not None:
if inst.opcode in jump_only_opcodes:
targets = [indexof[inst.target]]
elif inst.opcode in jump_or_next_opcodes:
targets = [indexof[inst.target], pc + 1]
else:
raise NotImplementedError(f"unhandled {inst.opname}")
else:
targets = [pc + 1]
for target in targets:
if instructions[target].opcode == return_value_opcode:
reach_end = True
break
if target in visited:
continue
if target == end_pc:
continue
visited.add(target)
queue.append(target)
if not reach_end:
possible_end_pcs.add(end_pc)
visited = set()
dist: dict[int, int] = {start_pc: 0}
queue = deque([start_pc])
while queue:
pc = queue.popleft()
inst = instructions[pc]
if inst.opcode == return_value_opcode:
continue
targets = []
if inst.target is not None:
if inst.opcode in jump_only_opcodes:
targets = [indexof[inst.target]]
elif inst.opcode in jump_or_next_opcodes:
targets = [indexof[inst.target], pc + 1]
else:
raise NotImplementedError(f"unhandled {inst.opname}")
else:
targets = [pc + 1]
for target in targets:
if target in visited:
continue
visited.add(target)
dist[target] = dist[pc] + 1
queue.append(target)
min_dist = min([dist[end_pc] for end_pc in possible_end_pcs])
for end_pc in possible_end_pcs:
if dist[end_pc] == min_dist:
return end_pc
return -1
8 changes: 8 additions & 0 deletions frontend/c_api.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,12 @@ def get_code_map(frame: FrameType) -> 'ProcessedCode':


def is_bound_method(obj: Any, name: str) -> bool:
pass


def parse_rangeiterobject(obj: Any) -> Tuple[int, int, int, int]:
pass


def make_rangeiterobject(start: int, stop: int, step: int) -> Any:
pass
3 changes: 3 additions & 0 deletions frontend/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ def get_inst(self, lasti: int) -> Instruction:
def get_pc_by_inst(self, inst: Instruction) -> int:
return self.guarded_pc[inst]

def is_match(self, original_pc: int, guard_pc: int) -> bool:
return self.pc_guarded_to_origin[guard_pc] == original_pc

def get_dependence_of_stack_var(self, original_inst: Instruction,
stack_depth: int) -> list[Instruction]:
raise NotImplementedError
Expand Down
4 changes: 3 additions & 1 deletion frontend/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,6 @@ def reset() -> None:
from . import utils
utils.reset()
from . import fx_graph
fx_graph.reset()
fx_graph.reset()
from . import dynamic
dynamic.reset()
12 changes: 11 additions & 1 deletion frontend/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
from typing import Callable, Any, Union

backend: Union[str, Callable[..., Any]] = "inductor"
CONFIG = {
"backend": "inductor", # Union[str, Callable[..., Any]]
}


def set_config(key: str, value: Any) -> None:
CONFIG[key] = value


def get_config(key: str) -> Any:
return CONFIG[key]
60 changes: 60 additions & 0 deletions frontend/control_flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import dataclasses
from typing import Any, Optional
import torch
from .store_pos import StorePos


@dataclasses.dataclass
class LoopPosMap:
input_only_pos: list[tuple[str, StorePos]]
joint_pos: list[tuple[str, StorePos]]
output_only_pos: list[tuple[str, StorePos]]


class LoopModule(torch.nn.Module): #type: ignore
body: torch.fx.GraphModule
num_read_only_param: int
num_iter: int

def __init__(self, body: torch.fx.GraphModule, num_read_only_param: int,
num_iter: int):
super(LoopModule, self).__init__()
self.body = body
self.num_read_only_param = num_read_only_param
self.num_iter = num_iter

# def forward(self, num_iter: Optional[int], cond: torch.Tensor, *values:
# Any) -> Any:
def forward(self, *values: Any) -> Any:
iter_num = 0
# assert cond.dtype == torch.bool
read_only = values[:self.num_read_only_param]
loop_carry = values[self.num_read_only_param:]
while iter_num < self.num_iter:
# and cond.item():
loop_carry = self.body(iter_num, *read_only, *loop_carry)
# cond, *loop_carry = self.body(iter_num, cond, *read_only,
# *loop_carry)
iter_num += 1
return loop_carry


class ControlFlowInfo:
start_pc: int
end_pc: int

def __init__(self, start_pc: int, end_pc: int) -> None:
self.start_pc = start_pc
self.end_pc = end_pc


class ForLoopInfo(ControlFlowInfo):
num_iter: int
cur_iter: int
pos_map: Optional[LoopPosMap]
inner_graph: Optional[torch.fx.Graph]

def __init__(self, start_pc: int, end_pc: int, num_iter: int) -> None:
super().__init__(start_pc, end_pc)
self.num_iter = num_iter
self.cur_iter = 0
2 changes: 2 additions & 0 deletions frontend/csrc/csrc.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,7 @@ struct StackEffect {
bool local_effect, global_effect;
};
StackEffect stack_effect(int opcode, int oparg, int jump);
PyObject *parse_rangeiterobject(PyObject *self, PyObject *args);
PyObject *make_rangeiterobject(PyObject *self, PyObject *args);

} // namespace frontend_csrc
4 changes: 4 additions & 0 deletions frontend/csrc/frame_evaluation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,10 @@ static PyMethodDef _methods[] = {
METH_VARARGS, NULL},
{"get_code_map", get_code_map, METH_VARARGS, NULL},
{"is_bound_method", is_bound_method, METH_VARARGS, NULL},
{"parse_rangeiterobject", frontend_csrc::parse_rangeiterobject,
METH_VARARGS, NULL},
{"make_rangeiterobject", frontend_csrc::make_rangeiterobject, METH_VARARGS,
NULL},
{NULL, NULL, 0, NULL}};

static struct PyModuleDef _module = {
Expand Down
41 changes: 41 additions & 0 deletions frontend/csrc/parse_types.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#include <Python.h>
#include <object.h>

namespace frontend_csrc {

typedef struct {
PyObject_HEAD long index;
long start;
long step;
long len;
} rangeiterobject;

PyObject *parse_rangeiterobject(PyObject *self, PyObject *args) {
PyObject *obj;
if (!PyArg_ParseTuple(args, "O", &obj)) {
return NULL;
}
if (Py_TYPE(obj) != &PyRangeIter_Type) {
PyErr_SetString(PyExc_TypeError, "Expected rangeiterobject");
return NULL;
}
rangeiterobject *robj = (rangeiterobject *)obj;
return PyTuple_Pack(
4, PyLong_FromLong(robj->index), PyLong_FromLong(robj->start),
PyLong_FromLong(robj->step), PyLong_FromLong(robj->len));
}

PyObject *make_rangeiterobject(PyObject *self, PyObject *args) {
long index, start, step, len;
if (!PyArg_ParseTuple(args, "llll", &index, &start, &step, &len)) {
return NULL;
}
rangeiterobject *robj = PyObject_New(rangeiterobject, &PyRangeIter_Type);
robj->index = index;
robj->start = start;
robj->step = step;
robj->len = len;
return (PyObject *)robj;
}

} // namespace frontend_csrc
Loading

0 comments on commit cc4ef4d

Please sign in to comment.