Skip to content

Commit

Permalink
Merge branch 'master' into control_flow
Browse files Browse the repository at this point in the history
  • Loading branch information
heheda12345 committed Oct 18, 2023
2 parents 4abed81 + d6f9b69 commit a27558a
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 12 deletions.
4 changes: 2 additions & 2 deletions frontend/guard_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,8 +872,8 @@ def process_last_inst(self) -> None:
if self.state.num_new_refs == -1:
self.state.num_new_refs = get_value_stack_size(self.frame)
for i in range(self.state.num_new_refs):
self.state.object_refs.append(
get_value_stack_from_top(self.frame, i))
obj = get_value_stack_from_top(self.frame, i)
self.state.object_refs.append(obj)
self.state.num_new_refs = 0
for i, obj in enumerate(self.state.inplace_update_objs):
assert not isinstance(obj, torch.Tensor)
Expand Down
13 changes: 3 additions & 10 deletions frontend/object_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ def __init__(self) -> None:
self.objs_no_id = []

def add(self, var: Variable, value: Any) -> None:
if isinstance(value, bool):
self.objs_no_id.append(var)
elif id(value) in self.objs:
if id(value) in self.objs:
old_var = self.objs[id(value)]
if isinstance(old_var, AnyVar) and not isinstance(var, AnyVar):
self.objs[id(value)] = var
Expand Down Expand Up @@ -52,9 +50,7 @@ def get_all_with_id(self) -> list[Tuple[int, Variable]]:
return list(self.objs.items())

def get(self, value: Any, allow_unexist_const: bool = False) -> Variable:
if isinstance(value, bool):
return ScalarVar(value, True, False, None, [])
elif id(value) in self.objs:
if id(value) in self.objs:
return self.objs[id(value)]
elif allow_unexist_const:
if isinstance(value, get_args(CONST_TYPES)) or isinstance(
Expand All @@ -80,10 +76,7 @@ def get_or_make_var(self,
need_guard_check: bool,
fx_graph: Optional[FxGraph] = None,
extract_code_at_start: list[StorePos] = []) -> Variable:
if isinstance(value, bool):
return ScalarVar(value, True, need_guard_check, None,
extract_code_at_start)
elif id(value) in self.objs:
if id(value) in self.objs:
return self.objs[id(value)]
else:
return make_var_from_value(value, need_guard_check,
Expand Down
1 change: 1 addition & 0 deletions frontend/variables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
float: ScalarVar,
int: ScalarVar,
str: ScalarVar,
bool: ScalarVar,
torch.Tensor: TensorVar,
NullObject: NullVar,
type(None): NoneVar,
Expand Down

0 comments on commit a27558a

Please sign in to comment.