diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
index 78eeb58a19133d..d378c5a312a8f4 100644
--- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
+++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -2181,6 +2181,27 @@ class OneHotOpPattern
   }
 };
 
+class TemporalShiftOpPattern
+    : public pir::OpRewritePattern<paddle::dialect::TemporalShiftOp> {
+ public:
+  using pir::OpRewritePattern<
+      paddle::dialect::TemporalShiftOp>::OpRewritePattern;
+
+  bool MatchAndRewrite(paddle::dialect::TemporalShiftOp op,
+                       pir::PatternRewriter &rewriter) const override {
+    if (op->HasAttribute(kCanRunTrtAttr) &&
+        op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
+      return false;
+    }
+    if (!op->HasAttribute("shift_ratio") || !op->HasAttribute("seg_num")) {
+      VLOG(3) << "temporal_shift requires shift_ratio and seg_num attributes";
+      return false;
+    }
+    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
+    return true;
+  }
+};
+
 class InstanceNormOpPattern
     : public pir::OpRewritePattern<paddle::dialect::InstanceNormOp> {
  public:
@@ -2388,6 +2409,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
+    ps.Add(std::make_unique<TemporalShiftOpPattern>(context));
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
     return ps;
diff --git a/python/paddle/tensorrt/impls/others.py b/python/paddle/tensorrt/impls/others.py
index 3aff438e0417bc..2674f0d5cc24b8 100644
--- a/python/paddle/tensorrt/impls/others.py
+++ b/python/paddle/tensorrt/impls/others.py
@@ -27,6 +27,7 @@
     trt_concat,
     trt_prod,
     trt_shape,
+    trt_sub,
     trt_sum,
 )
 from paddle.tensorrt.register import converter_registry
@@ -274,6 +275,122 @@ def share_data_converter(network, paddle_op, inputs):
     return identity_layer.get_output(0)
 
 
+@converter_registry.register("pd_op.temporal_shift", trt_version="8.x")
+def temporal_shift_converter(network, paddle_op, inputs):
+    input_tensor = inputs[0]
+    shift_ratio = paddle_op.attrs()["shift_ratio"]
+    T = paddle_op.attrs()["seg_num"]
+    data_format = paddle_op.attrs().get("data_format", "NCHW")
+
+    if data_format == "NHWC":
+        # Transpose input to [N, C, H, W]
+        transpose_layer = network.add_shuffle(input_tensor)
+        transpose_layer.first_transpose = trt.Permutation([0, 3, 1, 2])
+        input_tensor = transpose_layer.get_output(0)
+
+    input_dims = input_tensor.shape
+    C, H, W = input_dims[1], input_dims[2], input_dims[3]
+
+    # Reshape input to [N, T, C, H, W]
+    reshape_layer = network.add_shuffle(input_tensor)
+    reshape_layer.reshape_dims = trt.Dims([-1, T, C, H, W])
+    input_tensor = reshape_layer.get_output(0)
+
+    # Pad input to [N, T + 2, C, H, W]: a FILL-mode slice starting at
+    # [0, -1, 0, 0, 0] adds one zero frame before and after the T axis.
+    pre_pad = add_1D_constant_layer(network, [0, 1, 0, 0, 0])
+    post_pad = add_1D_constant_layer(network, [0, 1, 0, 0, 0])
+    dims = 5
+    zeros = add_1D_constant_layer(network, [0] * dims)
+    start = trt_sub(network, zeros, pre_pad)
+    total_padding = trt_sum(network, pre_pad, post_pad)
+    input_shape = trt_shape(network, input_tensor)
+    size = trt_sum(network, input_shape, total_padding)
+    stride = [1] * dims
+    # Static placeholders; the dynamic start/size are set via set_input.
+    dummy = stride
+
+    slice_layer = network.add_slice(input_tensor, dummy, dummy, stride)
+    slice_layer.set_input(1, start)
+    slice_layer.set_input(2, size)
+
+    # SliceMode was renamed to SampleMode in TensorRT 8.5.
+    trt_version = trt.__version__.split('.')
+    if int(trt_version[0]) > 8 or (
+        int(trt_version[0]) == 8 and int(trt_version[1]) >= 5
+    ):
+        slice_layer.mode = trt.SampleMode.FILL
+    else:
+        slice_layer.mode = trt.SliceMode.FILL
+
+    slice_c = int(C * shift_ratio)
+    slice_c2 = int(C * shift_ratio * 2)
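+
+    # Each output frame t maps to index t + 1 of the padded tensor. The three
+    # slices below take channels [0, slice_c) from frame t - 1, channels
+    # [slice_c, slice_c2) from frame t + 1, and channels [slice_c2, C) from
+    # frame t itself.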
+    slice_start1 = zeros
+    slice_start2 = add_1D_constant_layer(network, [0, 2, slice_c, 0, 0])
+    slice_start3 = add_1D_constant_layer(network, [0, 1, slice_c2, 0, 0])
+
+    slice_size_base = trt_shape(network, input_tensor)
+    sub_size1 = add_1D_constant_layer(network, [0, 0, C - slice_c, 0, 0])
+    sub_size2 = add_1D_constant_layer(
+        network, [0, 0, C + slice_c - slice_c2, 0, 0]
+    )
+    sub_size3 = add_1D_constant_layer(network, [0, 0, slice_c2, 0, 0])
+
+    slice_size1 = trt_sub(network, slice_size_base, sub_size1)
+    slice_size2 = trt_sub(network, slice_size_base, sub_size2)
+    slice_size3 = trt_sub(network, slice_size_base, sub_size3)
+
+    slice1_layer = network.add_slice(
+        slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride
+    )
+    slice1_layer.set_input(1, slice_start1)
+    slice1_layer.set_input(2, slice_size1)
+    slice2_layer = network.add_slice(
+        slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride
+    )
+    slice2_layer.set_input(1, slice_start2)
+    slice2_layer.set_input(2, slice_size2)
+    slice3_layer = network.add_slice(
+        slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride
+    )
+    slice3_layer.set_input(1, slice_start3)
+    slice3_layer.set_input(2, slice_size3)
+
+    if slice_c == 0:
+        # slice_c == 0 leaves the first slice empty, so concatenate only
+        # the remaining two slices.
+        concat_inputs = [
+            slice2_layer.get_output(0),
+            slice3_layer.get_output(0),
+        ]
+    else:
+        concat_inputs = [
+            slice1_layer.get_output(0),
+            slice2_layer.get_output(0),
+            slice3_layer.get_output(0),
+        ]
+    concat_layer = network.add_concatenation(concat_inputs)
+    concat_layer.axis = 2
+
+    # Reshape output to [N * T, C, H, W]
+    reshape_layer3 = network.add_shuffle(concat_layer.get_output(0))
+    reshape_layer3.reshape_dims = trt.Dims([-1, C, H, W])
+
+    if data_format == "NHWC":
+        transpose_layer2 = network.add_shuffle(reshape_layer3.get_output(0))
+        transpose_layer2.first_transpose = trt.Permutation([0, 2, 3, 1])
+        output_tensor = transpose_layer2.get_output(0)
+    else:
+        output_tensor = reshape_layer3.get_output(0)
+
+    return output_tensor
+
+
 @converter_registry.register("pd_op.anchor_generator", trt_version="8.x")
 def anchor_generator_converter(network, paddle_op, inputs):
     inputs = inputs[0]
diff --git a/test/tensorrt/test_converter_others.py b/test/tensorrt/test_converter_others.py
index 8b201467137eec..23c9db02a705df 100644
--- a/test/tensorrt/test_converter_others.py
+++ b/test/tensorrt/test_converter_others.py
@@ -406,6 +406,155 @@ def test_trt_result(self):
         self.check_trt_result()
 
 
+class TestTemporalShiftTRTPatternBasic(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.nn.functional.temporal_shift
+        self.api_args = {
+            "x": np.random.random([4, 9, 7, 7]).astype(np.float32),
+            "seg_num": 2,
+            "shift_ratio": 0.2,
+            "data_format": "NCHW",
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [2, 9, 7, 7]}
+        self.opt_shape = {"x": [2, 9, 7, 7]}
+        self.max_shape = {"x": [8, 9, 7, 7]}
+
+    def test_trt_result_fp16(self):
+        self.check_trt_result(precision_mode="fp16")
+
+    def test_trt_result_fp32(self):
+        self.check_trt_result()
+
+
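+# With C=2 and shift_ratio=0.2, int(C * shift_ratio) == 0, so this case
+# exercises the converter's slice_c == 0 branch, which concatenates only
+# the last two slices.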
+class TestTemporalShiftTRTPatternZeroSlice(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.nn.functional.temporal_shift
+        self.api_args = {
+            "x": np.random.random([4, 2, 7, 7]).astype(np.float32),
+            "seg_num": 2,
+            "shift_ratio": 0.2,
+            "data_format": "NCHW",
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [2, 2, 7, 7]}
+        self.opt_shape = {"x": [2, 2, 7, 7]}
+        self.max_shape = {"x": [8, 2, 7, 7]}
+
+    def test_trt_result_fp16(self):
+        self.check_trt_result(precision_mode="fp16")
+
+    def test_trt_result_fp32(self):
+        self.check_trt_result()
+
+
+class TestTemporalShiftTRTPatternDifferentSegNum(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.nn.functional.temporal_shift
+        self.api_args = {
+            "x": np.random.random([4, 9, 7, 7]).astype(np.float32),
+            "seg_num": 4,
+            "shift_ratio": 0.2,
+            "data_format": "NCHW",
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [4, 9, 7, 7]}
+        self.opt_shape = {"x": [4, 9, 7, 7]}
+        self.max_shape = {"x": [8, 9, 7, 7]}
+
+    def test_trt_result_fp16(self):
+        self.check_trt_result(precision_mode="fp16")
+
+    def test_trt_result_fp32(self):
+        self.check_trt_result()
+
+
+class TestTemporalShiftTRTPatternDifferentShiftRatio(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.nn.functional.temporal_shift
+        self.api_args = {
+            "x": np.random.random([4, 9, 7, 7]).astype(np.float32),
+            "seg_num": 2,
+            "shift_ratio": 0.4,
+            "data_format": "NCHW",
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [2, 9, 7, 7]}
+        self.opt_shape = {"x": [2, 9, 7, 7]}
+        self.max_shape = {"x": [8, 9, 7, 7]}
+
+    def test_trt_result_fp16(self):
+        self.check_trt_result(precision_mode="fp16")
+
+    def test_trt_result_fp32(self):
+        self.check_trt_result()
+
+
+class TestTemporalShiftTRTPatternDifferentDataFormat(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.nn.functional.temporal_shift
+        self.api_args = {
+            "x": np.random.random([4, 9, 7, 7]).astype(np.float32),
+            "seg_num": 2,
+            "shift_ratio": 0.2,
+            "name": None,
+            "data_format": "NHWC",
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [2, 9, 7, 7]}
+        self.opt_shape = {"x": [2, 9, 7, 7]}
+        self.max_shape = {"x": [8, 9, 7, 7]}
+
+    def test_trt_result_fp16(self):
+        self.check_trt_result(precision_mode="fp16")
+
+    def test_trt_result_fp32(self):
+        self.check_trt_result()
+
+
+class TestTemporalShiftTRTPatternMinMaxShape(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.nn.functional.temporal_shift
+        self.api_args = {
+            "x": np.random.random([4, 9, 7, 7]).astype(np.float32),
+            "seg_num": 2,
+            "shift_ratio": 0.2,
+            "data_format": "NCHW",
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [2, 9, 7, 7]}
+        self.opt_shape = {"x": [2, 9, 7, 7]}
+        self.max_shape = {"x": [10, 9, 7, 7]}
+
+    def test_trt_result_fp16(self):
+        self.check_trt_result(precision_mode="fp16")
+
+    def test_trt_result_fp32(self):
+        self.check_trt_result()
+
+
+def wrapper_temporal_shift(x):
+    return paddle.nn.functional.temporal_shift(x=x, seg_num=2, shift_ratio=0.2)
+
+
+class TestTemporalShiftTRTPatternError1(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = wrapper_temporal_shift
+        self.api_args = {
+            "x": np.random.random([4, 9, 7, 7]).astype(np.float32),
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [2, 9, 7, 7]}
+        self.opt_shape = {"x": [2, 9, 7, 7]}
+        self.max_shape = {"x": [10, 9, 7, 7]}
+
+    def test_trt_result(self):
+        self.check_marker(expected_result=False)
+
+
 def affine_channel(x, scale_shape, bias_shape, layout):
     scale = paddle.static.create_parameter(
         shape=scale_shape, dtype='float32', name="scale"