Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[Compat][3.11] support for-loop #371

Merged
merged 21 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
calc_stack_effect,
get_instructions,
)
from ..instruction_utils.opcode_info import JumpDirection, PopJumpCond
from .dispatch_functions import (
operator_BAD,
operator_exception_match,
Expand Down Expand Up @@ -556,9 +557,9 @@ def __init__(self, code: types.CodeType, graph: FunctionGraph):
self._cells = {} # position to put cells
self._lasti = 0 # idx of instruction list
self._code = code
self._current_line: int = -1
self._instructions = get_instructions(self._code)
self._graph = graph
self._current_line: int = -1
self.new_code: types.CodeType | None = None
self.guard_fn = None
self._name = "Executor"
Expand Down Expand Up @@ -793,7 +794,11 @@ def DUP_TOP_TWO(self, instr: Instruction):
for ref in self.stack.peek[:2]:
self.stack.push(ref)

def _rot_top_n(self, n):
def ROT_N(self, instr: Instruction):
assert instr.argval is not None
self._rot_top_n(instr.argval)

def _rot_top_n(self, n: int):
# a1 a2 a3 ... an <- TOS
# the stack changes to
# an a1 a2 a3 an-1 <- TOS
Expand Down Expand Up @@ -2000,6 +2005,7 @@ def _break_graph_in_for_loop(
'''
# 0. prepare sub functions
# 0.1 find the range of loop body
assert for_iter.jump_to is not None
loop_body_start_idx = self.indexof(for_iter) + 1
loop_body_end_idx = self.indexof(for_iter.jump_to)
curent_stack = 1
Expand Down Expand Up @@ -2099,10 +2105,14 @@ def _break_graph_in_for_loop(
self._graph.pycode_gen.gen_store(name, self._code)

# 6. add jump if break
jump_if_break = self._graph.pycode_gen._add_instr("POP_JUMP_IF_FALSE")
jump_if_break = self._graph.pycode_gen.gen_pop_jump(
direction=JumpDirection.FORWARD, suffix=PopJumpCond.FALSE
)

# 7. add JUMP_ABSOLUTE to FOR_ITER
self._graph.pycode_gen._add_instr("JUMP_ABSOLUTE", jump_to=for_iter)
# 7. jump back to FOR_ITER
self._graph.pycode_gen.gen_jump(
for_iter, direction=JumpDirection.BACKWARD
)
nop = self._graph.pycode_gen._add_instr("NOP")
for_iter.jump_to = nop
jump_if_break.jump_to = nop
Expand All @@ -2129,6 +2139,7 @@ def _break_graph_in_for_loop(
def _inline_call_for_loop(
self, iterator: VariableBase, for_iter: Instruction
):
assert for_iter.jump_to is not None
pycode_gen = PyCodeGen(self._frame)
origin_instrs = get_instructions(pycode_gen._origin_code)

Expand All @@ -2138,6 +2149,7 @@ def _inline_call_for_loop(
all_used_vars = analysis_used_names_with_space(
origin_instrs, start_idx, end_idx
)

inputs = [
k
for k, v in all_used_vars.items()
Expand All @@ -2152,14 +2164,16 @@ def _inline_call_for_loop(

# 3. add break, continue marker and relocate jump
for_iter_instr = origin_instrs[start_idx]
assert for_iter_instr.jump_to is not None
out_loop_instr = for_iter_instr.jump_to

break_jump = pycode_gen._add_instr(
"JUMP_ABSOLUTE", jump_to=out_loop_instr
)
pycode_gen.gen_jump(out_loop_instr, direction=JumpDirection.FORWARD)
nop_for_continue = pycode_gen._add_instr("NOP")

jump = pycode_gen._add_instr("JUMP_ABSOLUTE", jump_to=for_iter_instr)
jump = pycode_gen.gen_jump(
for_iter_instr, direction=JumpDirection.BACKWARD
)

nop_for_break = pycode_gen._add_instr("NOP")

for instr in pycode_gen._instructions:
Expand Down
31 changes: 30 additions & 1 deletion sot/opcode_translator/executor/pycode_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from ..instruction_utils.opcode_info import (
PYOPCODE_CACHE_SIZE,
UNCONDITIONAL_JUMP,
JumpDirection,
PopJumpCond,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -731,7 +733,7 @@ def gen_push_null(self):
self.gen_store_fast('sys')
self.gen_load_fast('sys')
self.gen_load_method('getsizeof')
self._add_instr("POP_TOP")
self.gen_pop_top()

def gen_store_fast(self, name):
if name not in self._code_options["co_varnames"]:
Expand Down Expand Up @@ -871,6 +873,33 @@ def gen_swap(self, n):
else:
raise NotImplementedError("swap is not supported before python3.11")

def gen_jump(
self,
jump_to: Instruction | None = None,
*,
direction: JumpDirection = JumpDirection.FORWARD,
) -> Instruction:
if sys.version_info >= (3, 11):
return self._add_instr(f"JUMP_{direction.value}", jump_to=jump_to)
else:
return self._add_instr("JUMP_ABSOLUTE", jump_to=jump_to)

def gen_pop_jump(
self,
jump_to: Instruction | None = None,
*,
direction: JumpDirection = JumpDirection.FORWARD,
suffix: PopJumpCond = PopJumpCond.NONE,
) -> Instruction:
if sys.version_info >= (3, 11):
return self._add_instr(
f"POP_JUMP_{direction.value}_IF_{suffix.value}", jump_to=jump_to
)
else:
return self._add_instr(
f"POP_JUMP_IF_{suffix.value}", jump_to=jump_to
)

def gen_return(self):
self._add_instr("RETURN_VALUE")

Expand Down
52 changes: 44 additions & 8 deletions sot/opcode_translator/instruction_utils/instruction_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import sys
from typing import TYPE_CHECKING, Any

from .opcode_info import ALL_JUMP, REL_JUMP
from ...utils import InnerError
from .opcode_info import ABS_JUMP, ALL_JUMP, REL_BWD_JUMP, REL_JUMP

if TYPE_CHECKING:
import types
Expand Down Expand Up @@ -155,6 +156,40 @@ def reset_offset(instructions: list[Instruction]) -> None:
instr.offset = idx * 2


def correct_jump_direction(instr: Instruction, arg: int) -> Instruction:
"""
Corrects the jump direction of the given instruction.
NOTE(zrr1999): In Python 3.11, JUMP_ABSOLUTE is removed, so python generates JUMP_FORWARD or JUMP_BACKWARD instead,
but in for loop breakgraph, we reuse JUMP_BACKWARD to jump forward, so we need to change it to JUMP_FORWARD.

Args:
instr (Instruction): The instruction to be corrected.
"""
if instr.opname in ABS_JUMP:
instr.arg = arg
return instr
elif instr.opname in REL_JUMP:
if arg < 0:
if instr.opname in REL_BWD_JUMP:
forward_op_name = instr.opname.replace("BACKWARD", "FORWARD")
if forward_op_name not in dis.opmap:
raise InnerError(f"Unknown jump type {instr.opname}")
instr.opname = forward_op_name
instr.opcode = dis.opmap[forward_op_name]
else: # instr.opname in REL_FWD_JUMP
backward_op_name = instr.opname.replace("FORWARD", "BACKWARD")
if backward_op_name not in dis.opmap:
raise InnerError(f"Unknown jump type {instr.opname}")
instr.opname = backward_op_name
instr.opcode = dis.opmap[backward_op_name]
instr.arg = -arg
else:
instr.arg = arg
return instr
else:
raise ValueError(f"unknown jump type: {instr.opname}")


def relocate_jump_target(instructions: list[Instruction]) -> None:
"""
If a jump instruction is found, this function will adjust the jump targets based on the presence of EXTENDED_ARG instructions.
Expand Down Expand Up @@ -183,16 +218,19 @@ def relocate_jump_target(instructions: list[Instruction]) -> None:
)
assert jump_target is not None

if instr.opname in REL_JUMP:
new_arg = jump_target - instr.offset - 2
else: # instr.opname in ABS_JUMP
if instr.opname in ABS_JUMP:
new_arg = jump_target
else: # instr.opname in REL_JUMP
new_arg = jump_target - instr.offset - 2
if instr.opname in REL_BWD_JUMP:
new_arg = -new_arg

if sys.version_info >= (3, 10):
new_arg //= 2

correct_jump_direction(instr, new_arg)
assert instr.arg is not None
if extended_arg:
instr.arg = new_arg & 0xFF
instr.arg &= 0xFF
new_arg = new_arg >> 8
for ex in reversed(extended_arg):
ex.arg = new_arg & 0xFF
Expand All @@ -202,8 +240,6 @@ def relocate_jump_target(instructions: list[Instruction]) -> None:
# set arg in the first extended_arg
if new_arg > 0:
extended_arg[0].arg += new_arg << 8
else:
instr.arg = new_arg
extended_arg.clear()


Expand Down
15 changes: 15 additions & 0 deletions sot/opcode_translator/instruction_utils/opcode_info.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import sys
from enum import Enum

import opcode

REL_JUMP = {opcode.opname[x] for x in opcode.hasjrel}
REL_BWD_JUMP = {opname for opname in REL_JUMP if "BACKWARD" in opname}
REL_FWD_JUMP = REL_JUMP - REL_BWD_JUMP
ABS_JUMP = {opcode.opname[x] for x in opcode.hasjabs}
HAS_LOCAL = {opcode.opname[x] for x in opcode.haslocal}
HAS_FREE = {opcode.opname[x] for x in opcode.hasfree}
Expand All @@ -12,6 +15,18 @@
UNCONDITIONAL_JUMP.add("JUMP_BACKWARD")


class JumpDirection(Enum):
FORWARD = "FORWARD"
BACKWARD = "BACKWARD"


class PopJumpCond(Enum):
FALSE = "FALSE"
TRUE = "TRUE"
NONE = "NONE"
NOT_NONE = "NOT_NONE"


# Cache for some opcodes, it's for Python 3.11+
# https://github.com/python/cpython/blob/3.11/Include/internal/pycore_opcode.h#L41-L53
PYOPCODE_CACHE_SIZE = {
Expand Down
4 changes: 2 additions & 2 deletions sot/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,8 @@ def __init__(self):
def clear(self):
self.graph_num = 0
self.op_num = 0
self.graphs: list = []
self.ops: list = []
self.graphs = []
self.ops = []

def get_graph_num(self):
return self.graph_num
Expand Down
8 changes: 0 additions & 8 deletions tests/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,9 @@ echo "IS_PY311:" $IS_PY311
failed_tests=()

py311_skiped_tests=(
./test_12_for_loop.py
./test_15_slice.py
./test_19_closure.py
./test_21_global.py
./test_enumerate.py
./test_guard_user_defined_fn.py
./test_inplace_api.py
./test_range.py
./test_resnet.py
./test_resnet50_backward.py
# ./test_side_effects.py There are some case need to be fixed
./test_tensor_dtype_in_guard.py
)

Expand Down
8 changes: 0 additions & 8 deletions tests/test_side_effects.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import sys
import unittest

from test_case_base import TestCaseBase, strict_mode_guard
Expand Down Expand Up @@ -234,10 +233,6 @@ def test_list_reverse(self):
list_reverse, [-1, 2, -3, 4, -5, 6, -7, 8, -9]
)

@unittest.skipIf(
sys.version_info >= (3, 11),
"Python 3.11+ not support for-loop breakgraph",
)
def test_slice_in_for_loop(self):
x = 2
with strict_mode_guard(0):
Expand All @@ -247,9 +242,6 @@ def test_list_nested(self):
self.assert_results_with_side_effects(list_nested, [1, 2, 3])


@unittest.skipIf(
sys.version_info >= (3, 11), "Python 3.11+ not support for-loop breakgraph"
)
class TestSliceAfterChange(TestCaseBase):
def test_slice_list_after_change(self):
self.assert_results_with_side_effects(
Expand Down