Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SOT] Add MIN_GRAPH_SIZE=10 test in dy2st tests #59191

Merged
merged 8 commits into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions test/dygraph_to_static/dygraph_to_static_utils_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
from paddle import set_flags, static
from paddle.base import core
from paddle.jit.api import sot_mode_guard
from paddle.jit.sot.opcode_translator.executor.executor_cache import (
OpcodeExecutorCache,
)
from paddle.jit.sot.utils.envs import min_graph_size_guard

"""
# Usage:
Expand Down Expand Up @@ -54,6 +58,8 @@ def test_case1(self):
class ToStaticMode(Flag):
AST = auto()
SOT = auto()
# SOT with MIN_GRAPH_SIZE=10, we only test SOT_MGS10 + LEGACY_IR to avoid regression
SOT_MGS10 = auto()

def lower_case_name(self):
return self.name.lower()
Expand All @@ -70,13 +76,15 @@ def lower_case_name(self):
return self.name.lower()


DEFAULT_TO_STATIC_MODE = ToStaticMode.AST | ToStaticMode.SOT
DEFAULT_TO_STATIC_MODE = (
ToStaticMode.AST | ToStaticMode.SOT | ToStaticMode.SOT_MGS10
)
DEFAULT_IR_MODE = IrMode.LEGACY_IR


def to_legacy_ast_test(fn):
"""
convert run fall_back to ast
convert run AST
"""

@wraps(fn)
Expand All @@ -90,14 +98,34 @@ def impl(*args, **kwargs):

def to_sot_test(fn):
"""
convert run fall_back to ast
convert run SOT
"""

@wraps(fn)
def impl(*args, **kwargs):
logger.info("[SOT] running SOT")

OpcodeExecutorCache().clear()
with sot_mode_guard(True):
fn(*args, **kwargs)
with min_graph_size_guard(0):
fn(*args, **kwargs)

return impl


def to_sot_mgs10_test(fn):
"""
convert run SOT and MIN_GRAPH_SIZE=10
"""

@wraps(fn)
def impl(*args, **kwargs):
logger.info("[SOT_MGS10] running SOT")

OpcodeExecutorCache().clear()
with sot_mode_guard(True):
with min_graph_size_guard(10):
fn(*args, **kwargs)

return impl

Expand Down Expand Up @@ -148,8 +176,9 @@ def impl(*args, **kwargs):
# Metaclass and BaseClass
class Dy2StTestMeta(type):
TO_STATIC_HANDLER_MAP = {
ToStaticMode.SOT: to_sot_test,
ToStaticMode.AST: to_legacy_ast_test,
ToStaticMode.SOT: to_sot_test,
ToStaticMode.SOT_MGS10: to_sot_mgs10_test,
}

IR_HANDLER_MAP = {
Expand Down Expand Up @@ -204,6 +233,12 @@ def __new__(cls, name, bases, attrs):
)
# Generate all test cases
for to_static_mode, ir_mode in to_static_with_ir_modes:
if (
to_static_mode == ToStaticMode.SOT_MGS10
and ir_mode != IrMode.LEGACY_IR
):
# SOT_MGS10 only test with LEGACY_IR
continue
new_attrs[
Dy2StTestMeta.test_case_name(
fn_name, to_static_mode, ir_mode
Expand Down Expand Up @@ -262,7 +297,7 @@ def test_ast_only(fn):


def test_sot_only(fn):
fn = set_to_static_mode(ToStaticMode.SOT)(fn)
fn = set_to_static_mode(ToStaticMode.SOT | ToStaticMode.SOT_MGS10)(fn)
return fn


Expand Down
6 changes: 2 additions & 4 deletions test/dygraph_to_static/test_gradname_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
import numpy as np
from dygraph_to_static_utils_new import (
Dy2StTestBase,
test_ast_only,
test_pir_api_only,
test_legacy_and_pir_api,
)

import paddle
Expand Down Expand Up @@ -86,8 +85,7 @@ def setUp(self):
self.dy2st_input = (x2,)
self.dy2st_grad_input = (x2,)

@test_ast_only
@test_pir_api_only
@test_legacy_and_pir_api
def test_run(self):
try:
dy_out = self.func(*self.dy_input)
Expand Down
1 change: 1 addition & 0 deletions test/dygraph_to_static/test_inplace_assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def func(x):
@test_legacy_and_pir
def test_case2(self):
def func(a, x):
x = 2 * x
x[:] = a * 2.0
return x

Expand Down
1 change: 1 addition & 0 deletions test/dygraph_to_static/test_param_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

import numpy as np
Expand Down
6 changes: 6 additions & 0 deletions test/dygraph_to_static/test_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import numpy as np
from dygraph_to_static_utils_new import (
Dy2StTestBase,
IrMode,
ToStaticMode,
disable_test_case,
)
from seq2seq_dygraph_model import AttentionModel, BaseModel
from seq2seq_utils import Seq2SeqModelHyperParams, get_data_iter
Expand Down Expand Up @@ -236,10 +239,13 @@ def _test_predict(self, attn_model=False):
msg=f"\npred_dygraph = {pred_dygraph} \npred_static = {pred_static}",
)

# Disable duplicated test case to avoid timeout
@disable_test_case((ToStaticMode.SOT_MGS10, IrMode.LEGACY_IR))
def test_base_model(self):
self._test_train(attn_model=False)
self._test_predict(attn_model=False)

@disable_test_case((ToStaticMode.SOT_MGS10, IrMode.LEGACY_IR))
def test_attn_model(self):
self._test_train(attn_model=True)
# TODO(liym27): add predict
Expand Down
5 changes: 5 additions & 0 deletions test/dygraph_to_static/test_to_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
import numpy
from dygraph_to_static_utils_new import (
Dy2StTestBase,
IrMode,
ToStaticMode,
disable_test_case,
test_legacy_and_pir_exe_and_pir_api,
test_legacy_only,
test_pir_api_only,
Expand Down Expand Up @@ -165,7 +168,9 @@ def test_to_tensor_default_dtype(self):
self.assertTrue(a.stop_gradient == b.stop_gradient)
self.assertTrue(a.place._equals(b.place))

# MIN_GRAPH_SIZE=10 will cause fallback and raise error in dygraph
@test_legacy_and_pir_exe_and_pir_api
@disable_test_case((ToStaticMode.SOT_MGS10, IrMode.LEGACY_IR))
def test_to_tensor_err_log(self):
paddle.disable_static()
x = paddle.to_tensor([3])
Expand Down