【SCU】【Paddle TensorRT No.57】Add pd_op.temporal_shift converter #69848

Merged · 26 commits · Jan 14, 2025
Changes from 4 commits
29 changes: 29 additions & 0 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
Contributor Author

The temporal_shift kernel no longer supports inputs with x.size() != 4, and keeping this condition makes the coverage CI fail, so the condition has been removed.
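For context, a minimal sketch of the shape contract this implies (illustrative values; paddle.nn.functional.temporal_shift is the public API this converter maps):

import paddle
import paddle.nn.functional as F

# The kernel only accepts rank-4 input [N*T, C, H, W], with seg_num dividing N*T.
x = paddle.rand([4, 9, 7, 7])  # N*T = 4, C = 9, H = W = 7
out = F.temporal_shift(x, seg_num=2, shift_ratio=0.2)
print(out.shape)  # [4, 9, 7, 7]; the shape is preserved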

@@ -2094,6 +2094,34 @@ class AssignValueOpPattern
}
};

class TemporalShiftOpPattern
: public pir::OpRewritePattern<paddle::dialect::TemporalShiftOp> {
public:
using pir::OpRewritePattern<
paddle::dialect::TemporalShiftOp>::OpRewritePattern;

bool MatchAndRewrite(paddle::dialect::TemporalShiftOp op,
pir::PatternRewriter &rewriter) const override {
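// Skip ops already marked as runnable by TRT; nothing to do.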
if (op->HasAttribute(kCanRunTrtAttr) &&
op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
if (!op->HasAttribute("shift_ratio") || !op->HasAttribute("seg_num")) {
VLOG(3) << "temporal shift need attributes : shift_ratio and seg_num";
return false;
}
auto x = op.operand_source(0);
auto x_shape = pir::GetShapeFromValue(x);
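// The TRT conversion below assumes a rank-4 input laid out as [N*T, C, H, W].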
if (x_shape.size() != 4) {
VLOG(3) << "The input and grid tensors must be shape tensors of rank 4 "
"when using TRT TemporalShift layer.";
return false;
}
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};

class TrtOpMarkerPass : public pir::PatternRewritePass {
public:
TrtOpMarkerPass() : pir::PatternRewritePass("trt_op_marker_pass", 2) {}
@@ -2207,6 +2235,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ps.Add(std::make_unique<OneHotOpPattern>(context));
ps.Add(std::make_unique<AssignValueOpPattern>(context));
ps.Add(std::make_unique<AssignValue_OpPattern>(context));
ps.Add(std::make_unique<TemporalShiftOpPattern>(context));
return ps;
}
};
107 changes: 107 additions & 0 deletions python/paddle/tensorrt/impls/others.py
@@ -26,6 +26,7 @@
trt_concat,
trt_prod,
trt_shape,
trt_sub,
trt_sum,
)
from paddle.tensorrt.register import converter_registry
@@ -299,3 +300,109 @@ def share_data_converter(network, paddle_op, inputs):
identity_layer = network.add_identity(x)

return identity_layer.get_output(0)


@converter_registry.register("pd_op.temporal_shift", trt_version="8.x")
def temporal_shift_converter(network, paddle_op, inputs):
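# Overall scheme: reshape [N*T, C, H, W] to [N, T, C, H, W], zero-pad T to
# T + 2 with a FILL-mode slice, take three shifted slices along T/C,
# concatenate them back along C, then restore the original layout.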
input_tensor = inputs[0]
shift_ratio = paddle_op.attrs().get("shift_ratio")
T = paddle_op.attrs().get("seg_num")
data_format = paddle_op.attrs().get("data_format", "NCHW")

if data_format == "NHWC":
# Transpose input to [N, C, H, W]
transpose_layer = network.add_shuffle(input_tensor)
transpose_layer.first_transpose = trt.Permutation([0, 3, 1, 2])
input_tensor = transpose_layer.get_output(0)

input_dims = input_tensor.shape
C, H, W = input_dims[1], input_dims[2], input_dims[3]

# Reshape input to [N, T, C, H, W]
reshape_layer = network.add_shuffle(input_tensor)
reshape_layer.reshape_dims = trt.Dims([-1, T, C, H, W])
input_tensor = reshape_layer.get_output(0)

# Pad input to [N, T + 2, C, H, W]
pre_pad = add_1D_constant_layer(network, [0, 1, 0, 0, 0])
post_pad = add_1D_constant_layer(network, [0, 1, 0, 0, 0])
dims = 5
zeros = add_1D_constant_layer(network, [0] * dims)
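# Slice with a negative start and an enlarged size; in FILL mode the
# out-of-range reads become the zero padding.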
start = trt_sub(network, zeros, pre_pad)
total_padding = trt_sum(network, pre_pad, post_pad)
input_shape = trt_shape(network, input_tensor)
size = trt_sum(network, input_shape, total_padding)
stride = [1] * dims
dummy = stride
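# 'dummy' only reserves the start/shape slots of add_slice; the real
# dynamic values are wired in through set_input below.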

slice_layer = network.add_slice(input_tensor, dummy, dummy, stride)
slice_layer.set_input(1, start)
slice_layer.set_input(2, size)

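# TensorRT 8.5 renamed SliceMode to SampleMode; select whichever this
# build provides.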
trt_version = trt.__version__.split('.')
if int(trt_version[0]) > 8 or (
int(trt_version[0]) == 8 and int(trt_version[1]) >= 5
):
slice_layer.mode = trt.SampleMode.FILL
else:
slice_layer.mode = trt.SliceMode.FILL

slice_c = int(C * shift_ratio)
slice_c2 = int(C * shift_ratio * 2)

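# Three channel groups: channels [0, slice_c) read padded frames [0, T)
# (one step back), channels [slice_c, slice_c2) read frames [2, T + 2)
# (one step forward), and the remaining channels read frames [1, T + 1)
# (unshifted).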
slice_start1 = zeros
slice_start2 = add_1D_constant_layer(network, [0, 2, slice_c, 0, 0])
slice_start3 = add_1D_constant_layer(network, [0, 1, slice_c2, 0, 0])

slice_size_base = trt_shape(network, input_tensor)
sub_size1 = add_1D_constant_layer(network, [0, 0, C - slice_c, 0, 0])
sub_size2 = add_1D_constant_layer(
network, [0, 0, C + slice_c - slice_c2, 0, 0]
)
sub_size3 = add_1D_constant_layer(network, [0, 0, slice_c2, 0, 0])

slice_size1 = trt_sub(network, slice_size_base, sub_size1)
slice_size2 = trt_sub(network, slice_size_base, sub_size2)
slice_size3 = trt_sub(network, slice_size_base, sub_size3)

slice1_layer = network.add_slice(
slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride
)
slice1_layer.set_input(1, slice_start1)
slice1_layer.set_input(2, slice_size1)
slice2_layer = network.add_slice(
slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride
)
slice2_layer.set_input(1, slice_start2)
slice2_layer.set_input(2, slice_size2)
slice3_layer = network.add_slice(
slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride
)
slice3_layer.set_input(1, slice_start3)
slice3_layer.set_input(2, slice_size3)

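# A small shift_ratio can make slice_c round down to 0; the first slice
# is then empty and must be left out of the concatenation.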
if slice_c == 0:
concat_inputs = [slice2_layer.get_output(0), slice3_layer.get_output(0)]
concat_layer = network.add_concatenation(concat_inputs)
concat_layer.axis = 2
else:
concat_inputs = [
slice1_layer.get_output(0),
slice2_layer.get_output(0),
slice3_layer.get_output(0),
]
concat_layer = network.add_concatenation(concat_inputs)
concat_layer.axis = 2

# Reshape output to [N*T, C, H, W]
reshape_layer = network.add_shuffle(concat_layer.get_output(0))
reshape_layer.reshape_dims = trt.Dims(inputs[0].shape)

if data_format == "NHWC":
transpose_layer = network.add_shuffle(reshape_layer.get_output(0))
transpose_layer.first_transpose = trt.Permutation([0, 2, 3, 1])
output_tensor = transpose_layer.get_output(0)
else:
output_tensor = reshape_layer.get_output(0)

return output_tensor
85 changes: 85 additions & 0 deletions test/tensorrt/test_converter_others.py
@@ -394,5 +394,90 @@ def test_trt_result(self):
self.check_trt_result()


class TestTemporalShiftTRTPatternBasic(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.nn.functional.temporal_shift
self.api_args = {
"x": np.random.random([4, 9, 7, 7]).astype(np.float32),
"seg_num": 2,
"shift_ratio": 0.2,
"data_format": "NCHW",
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [2, 9, 7, 7]}
self.max_shape = {"x": [8, 9, 7, 7]}

def test_trt_result(self):
self.check_trt_result()


class TestTemporalShiftTRTPatternDifferentSegNum(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.nn.functional.temporal_shift
self.api_args = {
"x": np.random.random([4, 9, 7, 7]).astype(np.float32),
"seg_num": 4,
"shift_ratio": 0.2,
"data_format": "NCHW",
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [4, 9, 7, 7]}
self.max_shape = {"x": [8, 9, 7, 7]}

def test_trt_result(self):
self.check_trt_result()


class TestTemporalShiftTRTPatternDifferentShiftRatio(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.nn.functional.temporal_shift
self.api_args = {
"x": np.random.random([4, 9, 7, 7]).astype(np.float32),
"seg_num": 2,
"shift_ratio": 0.4,
"data_format": "NCHW",
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [2, 9, 7, 7]}
self.max_shape = {"x": [8, 9, 7, 7]}

def test_trt_result(self):
self.check_trt_result()


class TestTemporalShiftTRTPatternDifferentDataFormat(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.nn.functional.temporal_shift
self.api_args = {
"x": np.random.random([4, 9, 7, 7]).astype(np.float32),
"seg_num": 2,
"shift_ratio": 0.2,
"data_format": "NHWC",
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [2, 9, 7, 7]}
self.max_shape = {"x": [8, 9, 7, 7]}

def test_trt_result(self):
self.check_trt_result()


class TestTemporalShiftTRTPatternMinMaxShape(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.nn.functional.temporal_shift
self.api_args = {
"x": np.random.random([4, 9, 7, 7]).astype(np.float32),
"seg_num": 2,
"shift_ratio": 0.2,
"data_format": "NCHW",
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [2, 9, 7, 7]}
self.max_shape = {"x": [10, 9, 7, 7]}

def test_trt_result(self):
self.check_trt_result()


if __name__ == '__main__':
unittest.main()