Skip to content

Commit

Permalink
add bool
Browse files Browse the repository at this point in the history
  • Loading branch information
superDong1998 committed Oct 16, 2023
1 parent 7d8ccad commit f9af194
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 13 deletions.
28 changes: 25 additions & 3 deletions frontend/guard_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from .pycode_generator import GraphFnCodegen, GuardFnCodegen
from .fx_graph import FxGraph, get_frame_root, is_leaf_module, NodeArgs
from .bytecode_analysis import livevars_analysis
from .variables.tuple_ import TupleVar
from .variables.base import Variable


Expand All @@ -35,6 +34,23 @@ class PartialVar:
inplace_ref: Any = None # None if not inplace


class Bool_():
value: bool

def __init__(self, value: bool) -> None:
self.value = value

#todo: overwirte xor, or, and, others uses super()
def __and__(self, operator: bool) -> "Bool_":
return Bool_(self.value and operator)

def __or__(self, operator: bool) -> "Bool_":
return Bool_(self.value or operator)

def __not__(self) -> "Bool_":
return Bool_(not self.value)


class State:
objects: ObjectTable
start_pc: int
Expand Down Expand Up @@ -639,8 +655,14 @@ def process_last_inst(self) -> None:
if self.state.num_new_refs == -1:
self.state.num_new_refs = get_value_stack_size(self.frame)
for i in range(self.state.num_new_refs):
self.state.object_refs.append(
get_value_stack_from_top(self.frame, i))
obj = get_value_stack_from_top(self.frame, i)
if isinstance(obj, bool):
new_bool = Bool_(obj)
var_bool = vs.ScalarVar(obj, True, False)
self.state.objects.update_by_id(var_bool, id(new_bool))
self.state.object_refs.append(new_bool)
else:
self.state.object_refs.append(obj)
self.state.num_new_refs = 0
for i, obj in enumerate(self.state.inplace_update_objs):
assert not isinstance(obj, torch.Tensor)
Expand Down
13 changes: 3 additions & 10 deletions frontend/object_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ def __init__(self) -> None:
self.objs_no_id = []

def add(self, var: Variable, value: Any) -> None:
if isinstance(value, bool):
self.objs_no_id.append(var)
elif id(value) in self.objs:
if id(value) in self.objs:
old_var = self.objs[id(value)]
old_var.extract_code_at_start.extend(var.extract_code_at_start)
old_var.need_guard_check |= var.need_guard_check
Expand Down Expand Up @@ -47,9 +45,7 @@ def get_all_with_id(self) -> list[Tuple[int, Variable]]:
return list(self.objs.items())

def get(self, value: Any, allow_unexist_const: bool = False) -> Variable:
if isinstance(value, bool):
return ScalarVar(value, True, False)
elif id(value) in self.objs:
if id(value) in self.objs:
return self.objs[id(value)]
elif allow_unexist_const:
if isinstance(value, get_args(CONST_TYPES)) or isinstance(
Expand All @@ -74,10 +70,7 @@ def get_or_make_var(self,
need_guard_check: bool,
fx_graph: Optional[FxGraph] = None,
extract_code_at_start: list[StorePos] = []) -> Variable:
if isinstance(value, bool):
return ScalarVar(value, True, need_guard_check,
extract_code_at_start)
elif id(value) in self.objs:
if id(value) in self.objs:
return self.objs[id(value)]
else:
return make_var_from_value(value, need_guard_check,
Expand Down
1 change: 1 addition & 0 deletions frontend/variables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
float: ScalarVar,
int: ScalarVar,
str: ScalarVar,
bool: ScalarVar,
torch.Tensor: TensorVar,
NullObject: NullVar,
type(None): NoneVar,
Expand Down

0 comments on commit f9af194

Please sign in to comment.