[Dy2St][AMP] add should_auto_cast attribute for each operator #58628

Merged (9 commits) on Nov 13, 2023
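In short, this change lets a with paddle.amp.auto_cast(False): block inside a to_static function mark exactly which operators AMP should leave alone, by attaching a should_auto_cast flag to every static-graph Operator. The sketch below is adapted from the new unit test at the end of this diff and assumes a Paddle build that includes this patch (the layer name Net is made up); it only illustrates the intended usage.

import paddle

class Net(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self._fc = paddle.nn.Linear(10, 10)

    @paddle.jit.to_static(full_graph=True)
    def forward(self, x):
        x = self._fc(x)                    # these ops end up with should_auto_cast == True
        with paddle.amp.auto_cast(False):  # rewritten to convert_auto_cast by Dy2St
            x = x.astype("float32")        # these ops end up with should_auto_cast == False
        return x + 1                       # back to should_auto_cast == True

The flag only takes effect once prepare_op_should_auto_cast (added in fp16_utils.py below) propagates the recorded auto_cast ranges onto the traced program.
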
python/paddle/base/framework.py (+42 -0)
@@ -2935,6 +2935,9 @@ def __init__(
# attr for static graph mode cuda graph
self._cuda_graph_attr = _current_cuda_graph_mode

# attr indicating whether this OP should be auto-cast in AMP mode
self._should_auto_cast: bool = True

op_maker = core.op_proto_and_checker_maker

if op_maker.kOpRoleAttrName() not in op_attrs:
@@ -3692,6 +3695,25 @@ def dist_attr(self, dist_attr):
"""
self.desc.dist_attr = dist_attr

def set_auto_cast(self, auto_cast):
"""
Set the auto cast attribute of this Operator.

Args:
auto_cast(bool): True if this Operator should be cast in AMP mode.
"""
self._should_auto_cast = auto_cast

@property
def should_auto_cast(self):
"""
Get the auto cast attribute of this Operator.

Returns:
bool: True if this Operator should be cast in AMP mode.
"""
return self._should_auto_cast


@signature_safe_contextmanager
def _stride_in_no_check_dy2st_diff():
@@ -6323,6 +6345,7 @@ def clone(self, for_test=False):
p._copy_param_info_from(self)
p._copy_data_info_from(self, pruned_origin_block_id_map)
p._copy_dist_param_info_from(self)
p._copy_operator_info_from(self)
return p

def _prune(self, targets):
@@ -6446,6 +6469,7 @@ def _prune_with_input(self, feeded_var_names, targets):
res._copy_param_info_from(self)
res._copy_data_info_from(self, pruned_origin_block_id_map)
res._copy_dist_param_info_from(self)
res._copy_operator_info_from(self)

return res

@@ -6961,6 +6985,24 @@ def _copy_data_info_from(self, other, pruned_origin_block_id_map=None):
if other_var.stop_gradient:
var.stop_gradient = True

def _copy_operator_info_from(self, other: "Program"):
"""
Copy Operator information from another program.

Args:
other(Program): Other program

Returns:
None
"""
if not isinstance(other, Program):
raise TypeError(
f"Function Program._copy_operator_info_from() needs to pass in a source Program, but received {type(other)}"
)
for dst_block, src_block in zip(self.blocks, other.blocks):
for dst_op, src_op in zip(dst_block.ops, src_block.ops):
dst_op.set_auto_cast(src_op.should_auto_cast)

def list_vars(self):
"""
Get all Tensors from this Program. An iterable object is returned.
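
A rough illustration of what the framework.py change provides (a sketch assuming a build with this patch; the program below is made up for illustration): the flag lives on the Python-side Operator, defaults to True, can be flipped with set_auto_cast, and survives Program.clone() and pruning because those paths now call _copy_operator_info_from.

import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data("x", [10, 10], "float32")
    y = paddle.nn.functional.relu(x)

relu_op = main.global_block().ops[-1]
print(relu_op.should_auto_cast)   # True by default
relu_op.set_auto_cast(False)      # mark this op to be skipped by AMP casting

cloned = main.clone()
print(cloned.global_block().ops[-1].should_auto_cast)  # False, copied by _copy_operator_info_from
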
python/paddle/jit/dy2static/convert_operators.py (+38 -0)
@@ -12,14 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import re
from contextlib import contextmanager

import paddle
from paddle.autograd.py_layer import PyLayerMeta
from paddle.base.data_feeder import convert_dtype
from paddle.base.dygraph.base import _convert_into_variable, in_to_static_mode
from paddle.base.framework import Variable, core, default_main_program
from paddle.pir import OpResult
from paddle.static.amp.fp16_utils import AmpOptions

from .py_layer import StaticPyLayer
from .utils import (
@@ -77,6 +81,9 @@ def convert_load(x):
if new_var is not None:
return new_var

if x is paddle.amp.auto_cast:
return convert_auto_cast

return x


@@ -805,6 +812,37 @@ def convert_pop(target, *args):
return _run_python_pop(target, *args)


@contextmanager
def convert_auto_cast(
enable=True,
custom_white_list=None,
custom_black_list=None,
level='O1',
dtype='float16',
use_promote=True,
):
from .program_translator import ProgramTranslator

if enable:
raise NotImplementedError("Locally enabling AMP is not supported yet")

amp_records = ProgramTranslator.get_instance()._amp_records
main_program = paddle.static.default_main_program()
current_block_idx = main_program.current_block_idx
current_block = main_program.current_block()
start_op_idx = len(current_block.ops)
amp_options = AmpOptions(
enable, custom_white_list, custom_black_list, level, dtype, use_promote
)
yield
end_op_idx = len(current_block.ops)
if current_block_idx not in amp_records:
amp_records[current_block_idx] = []
amp_records[current_block_idx].append(
(amp_options, start_op_idx, end_op_idx)
)


def _run_paddle_pop(array, *args):
if len(args) == 0:
idx = -1
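
To make the convert_operators.py piece concrete: during Dy2St transcription, convert_load hands back convert_auto_cast whenever user code loads paddle.amp.auto_cast, and each such with-block appends an (AmpOptions, start_op_idx, end_op_idx) tuple to the translator's _amp_records under the current block index. The sketch below is illustrative only (the function fn is made up) and assumes a build with this patch.

import paddle
from paddle.jit.dy2static.program_translator import ProgramTranslator

@paddle.jit.to_static(full_graph=True)
def fn(x):
    y = x * 2
    with paddle.amp.auto_cast(False):  # becomes convert_auto_cast during transcription
        y = y.astype("float32")
    return y + 1

fn.get_concrete_program(paddle.randn([4]))

# Maps block index -> [(AmpOptions, start_op_idx, end_op_idx), ...] covering the ops
# traced inside the with-block; it is cleared at the start of every fresh trace.
print(ProgramTranslator.get_instance()._amp_records)

Note that only auto_cast(False) is supported locally; auto_cast(True) raises NotImplementedError inside convert_auto_cast.
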
python/paddle/jit/dy2static/program_translator.py (+8 -0)
@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import collections
import inspect
import os
import threading
import warnings
import weakref
from typing import TYPE_CHECKING

import paddle.pir.core as ir_static
from paddle import decomposition
@@ -65,6 +68,9 @@
unwrap,
)

if TYPE_CHECKING:
from paddle.static.amp.fp16_utils import AmpOptions

__all__ = []

# For each traced function, we set `max_traced_program_count` = 10 to consider caching performance.
@@ -1282,6 +1288,7 @@ def from_func_spec(
)

new_name_generator = UniqueNameGenerator()
ProgramTranslator.get_instance()._amp_records.clear()

with framework.program_guard(main_program, startup_program):
with _to_static_mode_guard_(is_to_static=True), UniqueNameGuard(
@@ -1753,6 +1760,7 @@ def __init__(self):
self._program_cache = ProgramCache()
self._params_recorder = ParametersRecorder()
self._inplace_map = InplaceMap()
self._amp_records: dict[int, list[tuple[AmpOptions, int, int]]] = {}
self.enable_to_static = True

def enable(self, enable_to_static):
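
The program_translator.py change itself only adds the per-translator _amp_records cache and clears it at the start of each fresh trace in from_func_spec. Purely for reference, a hand-written example of its shape after tracing a single auto_cast(False) region (the name example_amp_records and the indices are made up):

from paddle.static.amp.fp16_utils import AmpOptions

# block index -> list of (AmpOptions, start_op_idx, end_op_idx); end is exclusive.
example_amp_records = {
    0: [
        (
            AmpOptions(
                enable=False,
                custom_white_list=None,
                custom_black_list=None,
                level="O1",
                dtype="float16",
                use_promote=True,
            ),
            5,   # index of the first op traced inside the with-block
            16,  # one past the index of the last such op
        )
    ],
}
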
python/paddle/static/amp/fp16_utils.py (+47 -0)
@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
from dataclasses import dataclass

import numpy as np

@@ -40,6 +43,16 @@
_fp16_guard_pattern = "__use_fp16__"


@dataclass
class AmpOptions:
enable: bool
custom_white_list: list[str] | None
custom_black_list: list[str] | None
level: str
dtype: str
use_promote: bool


def _rename_arg(op, old_name, new_name):
"""
If an op has old_name input and output, rename these input
@@ -586,6 +599,40 @@ def process_op_input_and_outputs(op, block, global_block, dtype):
return low_precison_var_names


def map_block(block, fn, parent_op=None):
fn(block, parent_op)
program = block.program
for op in block.ops:
if not op.has_attr("sub_block"):
continue
sub_block = program.blocks[op.attr("sub_block").id]
map_block(sub_block, fn, op)


def prepare_op_should_auto_cast(
program: paddle.static.Program,
amp_records: dict[int, list[tuple[AmpOptions, int, int]]],
):
amp_enable_op_map: dict[paddle.static.Operator, bool] = {}

def fill_amp_enable_op_map(block, parent_op):
block_idx = block.idx
ops = block.ops
for op in ops:
# Ops at the top level (no parent op) default to auto cast enabled
current_op_amp_options = amp_enable_op_map.get(parent_op, True)
if block_idx in amp_records:
for amp_options, start, end in amp_records[block_idx]:
if op.idx in range(start, end):
current_op_amp_options = amp_options.enable
break
amp_enable_op_map[op] = current_op_amp_options

map_block(program.global_block(), fill_amp_enable_op_map)
for op, enable in amp_enable_op_map.items():
op.set_auto_cast(enable)


def cast_model_to_fp16(
program,
amp_lists=None,
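
Finally, a sketch of how the new fp16_utils helper is exercised, mirroring the unit test below rather than prescribing the final integration point (the function fn is made up): prepare_op_should_auto_cast walks every block via map_block, resolves each op's flag from the recorded ranges (sub-block ops inherit from their parent op, top-level ops default to True), and writes it back with set_auto_cast.

import paddle
from paddle.jit.dy2static.program_translator import ProgramTranslator
from paddle.static.amp.fp16_utils import prepare_op_should_auto_cast

@paddle.jit.to_static(full_graph=True)
def fn(x):
    with paddle.amp.auto_cast(False):
        x = x.astype("float32")
    return x * 2

concrete_program, _ = fn.get_concrete_program(paddle.randn([4]))
program = concrete_program.main_program
prepare_op_should_auto_cast(program, ProgramTranslator.get_instance()._amp_records)

for op in program.global_block().ops:
    # Ops traced inside auto_cast(False) now report False; everything else stays True.
    print(op.type, op.should_auto_cast)
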
test/dygraph_to_static/test_local_cast.py (+128 -0)
@@ -0,0 +1,128 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE(SigureMo): This unittest does NOT need to run in PIR mode. Don't import Dy2StTestBase.

import unittest

import paddle
from paddle.jit.dy2static.program_translator import ProgramTranslator
from paddle.static.amp.fp16_utils import prepare_op_should_auto_cast


class LocalAutoCastLayer1(paddle.nn.Layer):
def __init__(self):
super().__init__()
self._fc = paddle.nn.Linear(10, 10)

@paddle.jit.to_static(full_graph=True)
def forward(self, x):
x = self._fc(x)
y = self._fc(x) * 2
with paddle.amp.auto_cast(False):
x = x.astype("float32")
y = y.astype("float32")
if x[0][0] > 1:
x = x + y
else:
x = x - y
x = x * 2

return x + 1


class LocalAutoCastLayer2(paddle.nn.Layer):
def __init__(self):
super().__init__()
self._fc = paddle.nn.Linear(10, 10)

@paddle.jit.to_static(full_graph=True)
def forward(self, x):
with paddle.amp.auto_cast(False):
x = x.astype("float32")
x = self._fc(x)
y = self._fc(x) * 2
if x[0][0] > 1:
x = x + y
else:
x = x - y
x = x * 2

return x + 1


class TestLocalCast(unittest.TestCase):
def get_auto_cast_ops_info_from_program(self, program):
auto_cast_ops_info = []
for block in program.blocks:
current_block_should_auto_cast = []
auto_cast_ops_info.append(current_block_should_auto_cast)
for op in block.ops:
current_block_should_auto_cast.append(op.should_auto_cast)
return auto_cast_ops_info

def should_auto_cast_for_each_ops(self, layer, input):
concrete_program, _ = layer.forward.get_concrete_program(input)
program = concrete_program.main_program
prepare_op_should_auto_cast(
program, ProgramTranslator.get_instance()._amp_records
)
auto_cast_ops_info = self.get_auto_cast_ops_info_from_program(program)
paddle.enable_static()
cloned_program = program.clone()
paddle.disable_static()
cloned_auto_cast_ops_info = self.get_auto_cast_ops_info_from_program(
cloned_program
)
self.assertEqual(auto_cast_ops_info, cloned_auto_cast_ops_info)
return auto_cast_ops_info

def test_should_auto_cast_1(self):
layer = LocalAutoCastLayer1()
input = paddle.randn([10, 10])
expected = [
# Only some of the ops in this block are inside the auto_cast(False) block
[
True, True, True, True, True,
False, False, False, False, False, False, False, False, False, False, False,
True,
],
# The whole if branch is inside the auto_cast(False) block
[False, False],
# The whole else branch is inside the auto_cast(False) block
[False, False, False],
] # fmt: skip
actual = self.should_auto_cast_for_each_ops(layer, input)
self.assertEqual(expected, actual)

def test_should_auto_cast_2(self):
layer = LocalAutoCastLayer2()
input = paddle.randn([10, 10])
expected = [
# Only some of the ops in this block are inside the auto_cast(False) block
[
False, False, False, False, False, False,
True, True, True, True, True, True, True, True, True, True,
],
# The whole if branch is outside the auto_cast(False) block
[True, True],
# The whole else branch is outside the auto_cast(False) block
[True, True, True],
] # fmt: skip
actual = self.should_auto_cast_for_each_ops(layer, input)
self.assertEqual(expected, actual)


if __name__ == '__main__':
unittest.main()