diff --git a/test/dygraph_to_static/dygraph_to_static_utils_new.py b/test/dygraph_to_static/dygraph_to_static_utils_new.py index 442440c2427e4e..48417f1b33efcd 100644 --- a/test/dygraph_to_static/dygraph_to_static_utils_new.py +++ b/test/dygraph_to_static/dygraph_to_static_utils_new.py @@ -27,6 +27,10 @@ from paddle import set_flags, static from paddle.base import core from paddle.jit.api import sot_mode_guard +from paddle.jit.sot.opcode_translator.executor.executor_cache import ( + OpcodeExecutorCache, +) +from paddle.jit.sot.utils.envs import min_graph_size_guard """ # Usage: @@ -54,6 +58,8 @@ def test_case1(self): class ToStaticMode(Flag): AST = auto() SOT = auto() + # SOT with MIN_GRAPH_SIZE=10, we only test SOT_MGS10 + LEGACY_IR to avoid regression + SOT_MGS10 = auto() def lower_case_name(self): return self.name.lower() @@ -70,13 +76,15 @@ def lower_case_name(self): return self.name.lower() -DEFAULT_TO_STATIC_MODE = ToStaticMode.AST | ToStaticMode.SOT +DEFAULT_TO_STATIC_MODE = ( + ToStaticMode.AST | ToStaticMode.SOT | ToStaticMode.SOT_MGS10 +) DEFAULT_IR_MODE = IrMode.LEGACY_IR def to_legacy_ast_test(fn): """ - convert run fall_back to ast + convert run AST """ @wraps(fn) @@ -90,14 +98,34 @@ def impl(*args, **kwargs): def to_sot_test(fn): """ - convert run fall_back to ast + convert run SOT """ @wraps(fn) def impl(*args, **kwargs): logger.info("[SOT] running SOT") + + OpcodeExecutorCache().clear() with sot_mode_guard(True): - fn(*args, **kwargs) + with min_graph_size_guard(0): + fn(*args, **kwargs) + + return impl + + +def to_sot_mgs10_test(fn): + """ + convert run SOT and MIN_GRAPH_SIZE=10 + """ + + @wraps(fn) + def impl(*args, **kwargs): + logger.info("[SOT_MGS10] running SOT") + + OpcodeExecutorCache().clear() + with sot_mode_guard(True): + with min_graph_size_guard(10): + fn(*args, **kwargs) return impl @@ -148,8 +176,9 @@ def impl(*args, **kwargs): # Metaclass and BaseClass class Dy2StTestMeta(type): TO_STATIC_HANDLER_MAP = { - ToStaticMode.SOT: to_sot_test, ToStaticMode.AST: to_legacy_ast_test, + ToStaticMode.SOT: to_sot_test, + ToStaticMode.SOT_MGS10: to_sot_mgs10_test, } IR_HANDLER_MAP = { @@ -204,6 +233,12 @@ def __new__(cls, name, bases, attrs): ) # Generate all test cases for to_static_mode, ir_mode in to_static_with_ir_modes: + if ( + to_static_mode == ToStaticMode.SOT_MGS10 + and ir_mode != IrMode.LEGACY_IR + ): + # SOT_MGS10 only test with LEGACY_IR + continue new_attrs[ Dy2StTestMeta.test_case_name( fn_name, to_static_mode, ir_mode @@ -262,7 +297,7 @@ def test_ast_only(fn): def test_sot_only(fn): - fn = set_to_static_mode(ToStaticMode.SOT)(fn) + fn = set_to_static_mode(ToStaticMode.SOT | ToStaticMode.SOT_MGS10)(fn) return fn diff --git a/test/dygraph_to_static/test_gradname_parse.py b/test/dygraph_to_static/test_gradname_parse.py index e15320fdc84880..154471a0e6d653 100644 --- a/test/dygraph_to_static/test_gradname_parse.py +++ b/test/dygraph_to_static/test_gradname_parse.py @@ -18,8 +18,7 @@ import numpy as np from dygraph_to_static_utils_new import ( Dy2StTestBase, - test_ast_only, - test_pir_api_only, + test_legacy_and_pir_api, ) import paddle @@ -86,8 +85,7 @@ def setUp(self): self.dy2st_input = (x2,) self.dy2st_grad_input = (x2,) - @test_ast_only - @test_pir_api_only + @test_legacy_and_pir_api def test_run(self): try: dy_out = self.func(*self.dy_input) diff --git a/test/dygraph_to_static/test_inplace_assign.py b/test/dygraph_to_static/test_inplace_assign.py index 8e3a19c62764c5..05cd4274ff6eee 100644 --- a/test/dygraph_to_static/test_inplace_assign.py +++ b/test/dygraph_to_static/test_inplace_assign.py @@ -54,6 +54,7 @@ def func(x): @test_legacy_and_pir def test_case2(self): def func(a, x): + x = 2 * x x[:] = a * 2.0 return x diff --git a/test/dygraph_to_static/test_param_guard.py b/test/dygraph_to_static/test_param_guard.py index 8e2e917c6af053..1e571e925a89dd 100644 --- a/test/dygraph_to_static/test_param_guard.py +++ b/test/dygraph_to_static/test_param_guard.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import unittest import numpy as np diff --git a/test/dygraph_to_static/test_seq2seq.py b/test/dygraph_to_static/test_seq2seq.py index 743b115583b39c..6cf35f58d0a6cd 100644 --- a/test/dygraph_to_static/test_seq2seq.py +++ b/test/dygraph_to_static/test_seq2seq.py @@ -20,6 +20,9 @@ import numpy as np from dygraph_to_static_utils_new import ( Dy2StTestBase, + IrMode, + ToStaticMode, + disable_test_case, ) from seq2seq_dygraph_model import AttentionModel, BaseModel from seq2seq_utils import Seq2SeqModelHyperParams, get_data_iter @@ -236,10 +239,13 @@ def _test_predict(self, attn_model=False): msg=f"\npred_dygraph = {pred_dygraph} \npred_static = {pred_static}", ) + # Disable duplicated test case to avoid timeout + @disable_test_case((ToStaticMode.SOT_MGS10, IrMode.LEGACY_IR)) def test_base_model(self): self._test_train(attn_model=False) self._test_predict(attn_model=False) + @disable_test_case((ToStaticMode.SOT_MGS10, IrMode.LEGACY_IR)) def test_attn_model(self): self._test_train(attn_model=True) # TODO(liym27): add predict diff --git a/test/dygraph_to_static/test_to_tensor.py b/test/dygraph_to_static/test_to_tensor.py index a0e29c2ed07481..bbd7568ffada94 100644 --- a/test/dygraph_to_static/test_to_tensor.py +++ b/test/dygraph_to_static/test_to_tensor.py @@ -17,6 +17,9 @@ import numpy from dygraph_to_static_utils_new import ( Dy2StTestBase, + IrMode, + ToStaticMode, + disable_test_case, test_legacy_and_pir_exe_and_pir_api, test_legacy_only, test_pir_api_only, @@ -165,7 +168,9 @@ def test_to_tensor_default_dtype(self): self.assertTrue(a.stop_gradient == b.stop_gradient) self.assertTrue(a.place._equals(b.place)) + # MIN_GRAPH_SIZE=10 will cause fallback and raise error in dygraph @test_legacy_and_pir_exe_and_pir_api + @disable_test_case((ToStaticMode.SOT_MGS10, IrMode.LEGACY_IR)) def test_to_tensor_err_log(self): paddle.disable_static() x = paddle.to_tensor([3])