diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py index 700646a93572..2f71b76b2c6f 100644 --- a/frontend/guard_tracker.py +++ b/frontend/guard_tracker.py @@ -23,7 +23,6 @@ from .pycode_generator import GraphFnCodegen, GuardFnCodegen from .fx_graph import FxGraph, get_frame_root, is_leaf_module, NodeArgs from .bytecode_analysis import livevars_analysis -from .variables.tuple_ import TupleVar from .variables.base import Variable @@ -35,6 +34,23 @@ class PartialVar: inplace_ref: Any = None # None if not inplace +class Bool_(): + value: bool + + def __init__(self, value: bool) -> None: + self.value = value + + #todo: overwirte xor, or, and, others uses super() + def __and__(self, operator: bool) -> "Bool_": + return Bool_(self.value and operator) + + def __or__(self, operator: bool) -> "Bool_": + return Bool_(self.value or operator) + + def __not__(self) -> "Bool_": + return Bool_(not self.value) + + class State: objects: ObjectTable start_pc: int @@ -639,8 +655,14 @@ def process_last_inst(self) -> None: if self.state.num_new_refs == -1: self.state.num_new_refs = get_value_stack_size(self.frame) for i in range(self.state.num_new_refs): - self.state.object_refs.append( - get_value_stack_from_top(self.frame, i)) + obj = get_value_stack_from_top(self.frame, i) + if isinstance(obj, bool): + new_bool = Bool_(obj) + var_bool = vs.ScalarVar(obj, True, False) + self.state.objects.update_by_id(var_bool, id(new_bool)) + self.state.object_refs.append(new_bool) + else: + self.state.object_refs.append(obj) self.state.num_new_refs = 0 for i, obj in enumerate(self.state.inplace_update_objs): assert not isinstance(obj, torch.Tensor) diff --git a/frontend/object_table.py b/frontend/object_table.py index a956844d3350..e2b2a9a6a2cf 100644 --- a/frontend/object_table.py +++ b/frontend/object_table.py @@ -17,9 +17,7 @@ def __init__(self) -> None: self.objs_no_id = [] def add(self, var: Variable, value: Any) -> None: - if isinstance(value, bool): - self.objs_no_id.append(var) - elif id(value) in self.objs: + if id(value) in self.objs: old_var = self.objs[id(value)] old_var.extract_code_at_start.extend(var.extract_code_at_start) old_var.need_guard_check |= var.need_guard_check @@ -47,9 +45,7 @@ def get_all_with_id(self) -> list[Tuple[int, Variable]]: return list(self.objs.items()) def get(self, value: Any, allow_unexist_const: bool = False) -> Variable: - if isinstance(value, bool): - return ScalarVar(value, True, False) - elif id(value) in self.objs: + if id(value) in self.objs: return self.objs[id(value)] elif allow_unexist_const: if isinstance(value, get_args(CONST_TYPES)) or isinstance( @@ -74,10 +70,7 @@ def get_or_make_var(self, need_guard_check: bool, fx_graph: Optional[FxGraph] = None, extract_code_at_start: list[StorePos] = []) -> Variable: - if isinstance(value, bool): - return ScalarVar(value, True, need_guard_check, - extract_code_at_start) - elif id(value) in self.objs: + if id(value) in self.objs: return self.objs[id(value)] else: return make_var_from_value(value, need_guard_check, diff --git a/frontend/variables/__init__.py b/frontend/variables/__init__.py index e71b9815fa38..0eb754bcd0e3 100644 --- a/frontend/variables/__init__.py +++ b/frontend/variables/__init__.py @@ -19,6 +19,7 @@ float: ScalarVar, int: ScalarVar, str: ScalarVar, + bool: ScalarVar, torch.Tensor: TensorVar, NullObject: NullVar, type(None): NoneVar,