Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2St] Refine dy2st unittest decorators name #58316

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions python/paddle/jit/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import types
import warnings
from collections import OrderedDict
from contextlib import contextmanager
from typing import Any

import paddle
Expand Down Expand Up @@ -55,6 +56,10 @@
from paddle.framework import in_dynamic_mode
from paddle.nn import Layer
from paddle.static.io import save_inference_model
from paddle.utils.environments import (
BooleanEnvironmentVariable,
EnvironmentVariableGuard,
)

from .dy2static import logging_utils
from .dy2static.convert_call_func import ConversionOptions, add_ignore_module
Expand All @@ -74,6 +79,14 @@
TranslatedLayer,
)

ENV_ENABLE_SOT = BooleanEnvironmentVariable("ENABLE_FALL_BACK", True)


@contextmanager
def sot_mode_guard(value: bool):
with EnvironmentVariableGuard(ENV_ENABLE_SOT, value):
yield


def create_program_from_desc(program_desc):
program = Program()
Expand Down Expand Up @@ -294,11 +307,8 @@ def decorated(python_func):

nonlocal full_graph
if full_graph is None:
flag = os.environ.get("ENABLE_FALL_BACK", None)
if flag == "True" or flag is None:
full_graph = False
else: # False
full_graph = True
flag = ENV_ENABLE_SOT.get()
full_graph = not flag

if sys.version_info >= (3, 12) and not full_graph:
warnings.warn(
Expand Down
45 changes: 12 additions & 33 deletions test/dygraph_to_static/dygraph_to_static_utils_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import inspect
import logging
import os
Expand All @@ -24,6 +23,7 @@

from paddle import set_flags, static
from paddle.base import core
from paddle.jit.api import sot_mode_guard

"""
# Usage:
Expand Down Expand Up @@ -69,21 +69,6 @@ def lower_case_name(self):
DEFAULT_IR_MODE = IrMode.LEGACY_PROGRAM


def in_sot_mode():
return os.getenv("ENABLE_FALL_BACK", "False") == "True"


@contextlib.contextmanager
def enable_fallback_guard(enable):
flag = os.environ.get("ENABLE_FALL_BACK", None)
os.environ["ENABLE_FALL_BACK"] = enable
yield
if flag is not None:
os.environ["ENABLE_FALL_BACK"] = flag
else:
del os.environ["ENABLE_FALL_BACK"]


def to_legacy_ast_test(fn):
"""
convert run fall_back to ast
Expand All @@ -92,7 +77,7 @@ def to_legacy_ast_test(fn):
@wraps(fn)
def impl(*args, **kwargs):
logger.info("[AST] running AST")
with enable_fallback_guard("False"):
with sot_mode_guard(False):
fn(*args, **kwargs)

return impl
Expand All @@ -106,7 +91,7 @@ def to_sot_test(fn):
@wraps(fn)
def impl(*args, **kwargs):
logger.info("[SOT] running SOT")
with enable_fallback_guard("True"):
with sot_mode_guard(True):
fn(*args, **kwargs)

return impl
Expand Down Expand Up @@ -263,22 +248,27 @@ def decorator(fn):

# Suger decorators
# These decorators can be simply composed by base decorators
def ast_only_test(fn):
def test_ast_only(fn):
fn = set_to_static_mode(ToStaticMode.LEGACY_AST)(fn)
return fn


def sot_only_test(fn):
def test_sot_only(fn):
fn = set_to_static_mode(ToStaticMode.SOT)(fn)
return fn


def test_with_new_ir(fn):
def test_pir_only(fn):
fn = set_ir_mode(IrMode.PIR)(fn)
return fn


def _test_and_compare_with_new_ir(fn):
def test_legacy_and_pir(fn):
fn = set_ir_mode(IrMode.LEGACY_PROGRAM | IrMode.PIR)(fn)
return fn


def compare_legacy_with_pir(fn):
@wraps(fn)
def impl(*args, **kwargs):
outs = fn(*args, **kwargs)
Expand All @@ -297,17 +287,6 @@ def impl(*args, **kwargs):
return impl


def test_and_compare_with_new_ir(need_check_output: bool = True):
def decorator(fn):
fn = set_ir_mode(IrMode.LEGACY_PROGRAM | IrMode.PIR)(fn)
if need_check_output:
logger.info(f"[need_check_output] {fn.__name__}")
fn = _test_and_compare_with_new_ir(fn)
return fn

return decorator


# For debug
def show_all_test_cases(test_class):
logger.info(f"[showing {test_class.__name__}]")
Expand Down
17 changes: 8 additions & 9 deletions test/dygraph_to_static/test_assert.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import numpy
from dygraph_to_static_utils_new import (
Dy2StTestBase,
ast_only_test,
test_and_compare_with_new_ir,
test_ast_only,
test_legacy_and_pir,
)

import paddle
Expand All @@ -37,7 +37,6 @@ def dyfunc_assert_non_variable(x=True):
assert x


# @dy2static_unittest
class TestAssertVariable(Dy2StTestBase):
def _run(self, func, x, with_exception, to_static):
paddle.jit.enable_to_static(to_static)
Expand All @@ -53,8 +52,8 @@ def _run_dy_static(self, func, x, with_exception):
self._run(func, x, with_exception, True)
self._run(func, x, with_exception, False)

@test_and_compare_with_new_ir(False)
@ast_only_test
@test_legacy_and_pir
@test_ast_only
def test_non_variable(self):
self._run_dy_static(
dyfunc_assert_non_variable, x=False, with_exception=True
Expand All @@ -63,8 +62,8 @@ def test_non_variable(self):
dyfunc_assert_non_variable, x=True, with_exception=False
)

@test_and_compare_with_new_ir(False)
@ast_only_test
@test_legacy_and_pir
@test_ast_only
def test_bool_variable(self):
self._run_dy_static(
dyfunc_assert_variable, x=numpy.array([False]), with_exception=True
Expand All @@ -73,8 +72,8 @@ def test_bool_variable(self):
dyfunc_assert_variable, x=numpy.array([True]), with_exception=False
)

@test_and_compare_with_new_ir(False)
@ast_only_test
@test_legacy_and_pir
@test_ast_only
def test_int_variable(self):
self._run_dy_static(
dyfunc_assert_variable, x=numpy.array([0]), with_exception=True
Expand Down
15 changes: 7 additions & 8 deletions test/dygraph_to_static/test_ast_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import numpy as np
from dygraph_to_static_utils_new import (
Dy2StTestBase,
ast_only_test,
test_and_compare_with_new_ir,
test_ast_only,
test_legacy_and_pir,
)
from ifelse_simple_func import (
dyfunc_with_if_else,
Expand All @@ -35,7 +35,6 @@
from paddle.utils import gast


# @dy2static_unittest
class TestAST2Func(Dy2StTestBase):
"""
TestCase for the transformation from ast.AST into python callable function.
Expand All @@ -48,15 +47,15 @@ def _ast2func(self, func):
transformed_func, _ = ast_to_func(ast_root, func)
return transformed_func

@ast_only_test
@test_ast_only
def test_ast2func(self):
def func(x, y):
return x + y

x, y = 10, 20
self.assertEqual(func(x, y), self._ast2func(func)(x, y))

@ast_only_test
@test_ast_only
def test_ast2func_dygraph(self):
paddle.disable_static()
funcs = [dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else]
Expand All @@ -68,8 +67,8 @@ def test_ast2func_dygraph(self):
test_ret = self._ast2func(func)(x_v).numpy()
self.assertTrue((true_ret == test_ret).all())

@test_and_compare_with_new_ir(False)
@ast_only_test
@test_legacy_and_pir
@test_ast_only
def test_ast2func_static(self):
paddle.enable_static()

Expand All @@ -88,7 +87,7 @@ def func(x):
ret = exe.run(main_program, fetch_list=[true_ret, test_ret])
self.assertTrue((ret[0] == ret[1]).all())

@ast_only_test
@test_ast_only
def test_ast2func_error(self):
with self.assertRaises(Exception) as e:
self.assertRaises(TypeError, ast_to_func("x = a + b", 'foo'))
Expand Down
11 changes: 3 additions & 8 deletions test/dygraph_to_static/test_backward_without_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
import unittest

import numpy as np
from dygraph_to_static_utils_new import (
Dy2StTestBase,
test_and_compare_with_new_ir,
)
from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir

import paddle

Expand All @@ -32,9 +29,8 @@ def forward(self, x):
return out


# @dy2static_unittest
class TestBackwardWithoutParams(Dy2StTestBase):
@test_and_compare_with_new_ir(False)
@test_legacy_and_pir
def test_run(self):
net = paddle.jit.to_static(Net())

Expand All @@ -57,9 +53,8 @@ def forward(self, x):
return y, out


# @dy2static_unittest
class TestZeroSizeNet(Dy2StTestBase):
@test_and_compare_with_new_ir(False)
@test_legacy_and_pir
def test_run(self):
net = paddle.jit.to_static(ZeroSizeNet())
x = paddle.ones([2, 2])
Expand Down
18 changes: 8 additions & 10 deletions test/dygraph_to_static/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import numpy as np
from dygraph_to_static_utils_new import (
Dy2StTestBase,
ast_only_test,
test_and_compare_with_new_ir,
test_ast_only,
test_legacy_and_pir,
)

from paddle import base
Expand Down Expand Up @@ -60,7 +60,6 @@ def test_mix_cast(x):
return x


# @dy2static_unittest
class TestCastBase(Dy2StTestBase):
def setUp(self):
self.place = (
Expand Down Expand Up @@ -89,9 +88,8 @@ def do_test(self):
res = self.func(self.input)
return res

@ast_only_test # TODO: add new symbolic only test.
@test_and_compare_with_new_ir(False)
# @set_to_static_mode(ToStaticMode.LEGACY_AST)
@test_ast_only # TODO: add new sot only test.
@test_legacy_and_pir
def test_cast_result(self):
res = self.do_test().numpy()
self.assertTrue(
Expand Down Expand Up @@ -156,8 +154,8 @@ def prepare(self):
def set_func(self):
self.func = to_static(full_graph=True)(test_mix_cast)

@ast_only_test # TODO: add new symbolic only test.
@test_and_compare_with_new_ir(False)
@test_ast_only # TODO: add new symbolic only test.
@test_legacy_and_pir
def test_cast_result(self):
res = self.do_test().numpy()
self.assertTrue(
Expand Down Expand Up @@ -188,8 +186,8 @@ def prepare(self):
def set_func(self):
self.func = to_static(full_graph=True)(test_not_var_cast)

@ast_only_test
@test_and_compare_with_new_ir(False)
@test_ast_only
@test_legacy_and_pir
def test_cast_result(self):
# breakpoint()
# print("run once!!!")
Expand Down
Loading