Skip to content

Commit

Permalink
add bitwise_and to TRT ElementWise Layer (#59214)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhink authored Nov 23, 2023
1 parent f2f48a5 commit f685b4c
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 2 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2926,6 +2926,7 @@ USE_TRT_CONVERTER(batch_norm);
USE_TRT_CONVERTER(concat);
USE_TRT_CONVERTER(dropout);
USE_TRT_CONVERTER(pad);
USE_TRT_CONVERTER(bitwise_and);
#if IS_TRT_VERSION_GE(8200)
USE_TRT_CONVERTER(pad3d);
USE_TRT_CONVERTER(einsum)
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ list(
rnn_op.cc
fill_constant_batch_size_like_op.cc
sum_op.cc
bitwise_and_op.cc
shape_op.cc
fill_constant_op.cc
fused_token_prune_op.cc
Expand Down
60 changes: 60 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/bitwise_and_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <NvInferRuntimeCommon.h>
#include <cstddef>
#include <iostream>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle {
namespace inference {
namespace tensorrt {

class BitwiseAndConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(4) << "convert bitwise_and op to tensorrt layer";

framework::OpDesc op_desc(op, nullptr);
nvinfer1::ILayer* layer = nullptr;

auto* input_tensor = engine_->GetITensor(op_desc.Input("X")[0]);
nvinfer1::DataType data_type = input_tensor->getType();

auto* y_tensor = engine_->GetITensor(op_desc.Input("Y")[0]);

// for bool type
if (data_type == nvinfer1::DataType::kBOOL) {
layer = TRT_ENGINE_ADD_LAYER(engine_,
ElementWise,
*input_tensor,
*y_tensor,
nvinfer1::ElementWiseOperation::kAND);
} else {
PADDLE_THROW(platform::errors::Fatal(
"bitwise_and TRT converter is only supported on bool"));
}

auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "bitwise_and", {output_name}, test_mode);
}
};

} // namespace tensorrt
} // namespace inference
} // namespace paddle

REGISTER_TRT_OP_CONVERTER(bitwise_and, BitwiseAndConverter);
35 changes: 33 additions & 2 deletions paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1754,6 +1754,35 @@ struct SimpleOpTypeSetTeller : public Teller {
}
}

if (op_type == "bitwise_and") {
#if IS_TRT_VERSION_LT(8400)
VLOG(3) << "bitwise_and is not supported when TensorRT < 8.4";
return false;
#endif
if (!with_dynamic_shape) {
VLOG(3) << "Ops(" << op_type << ") do not support static shape yet.";
return false;
}
auto x_var_name = desc.Input("X")[0];
auto y_var_name = desc.Input("Y")[0];
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto* x_var_desc = block->FindVar(x_var_name);
auto* y_var_desc = block->FindVar(y_var_name);
auto x_dtype = x_var_desc->GetDataType();
auto y_dtype = y_var_desc->GetDataType();
if (x_dtype != framework::proto::VarType::BOOL ||
y_dtype != framework::proto::VarType::BOOL) {
VLOG(3) << "the bitwise_and only support input of BOOL.";
return false;
}
}

if (op_type == "pad3d") {
#if !IS_TRT_VERSION_GE(8200)
VLOG(3) << "pad3d is not supported when TensorRT < 8.2";
Expand Down Expand Up @@ -2914,7 +2943,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"flip",
"quantize_linear",
"dequantize_linear",
"share_data"};
"share_data",
"bitwise_and"};

std::unordered_set<std::string> teller_set{
"matrix_multiply",
Expand Down Expand Up @@ -3083,7 +3113,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"flip",
"quantize_linear",
"dequantize_linear",
"share_data"};
"share_data",
"bitwise_and"};
};

struct GenericPluginTeller : public Teller {
Expand Down
153 changes: 153 additions & 0 deletions test/ir/inference/test_trt_convert_bitwise_and.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from functools import partial
from typing import List

import numpy as np
from program_config import ProgramConfig, TensorConfig
from trt_layer_auto_scan_test import TrtLayerAutoScanTest

import paddle.inference as paddle_infer


class TrtConvertBitwiseAndTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True

def sample_program_configs(self):
def generate_input(batch):
if self.dims == 4:
return np.random.random([batch, 3, 3, 24]).astype(np.int32)
elif self.dims == 3:
return np.random.random([batch, 3, 24]).astype(np.bool8)
elif self.dims == 2:
return np.random.random([batch, 24]).astype(np.bool_)

for dims in [2, 3, 4]:
for batch in [3, 6, 9]:
self.dims = dims
ops_config = [
{
"op_type": "bitwise_and",
"op_inputs": {
"X": ["input_data1"],
"Y": ["input_data2"],
},
"op_outputs": {"Out": ["output_data"]},
"op_attrs": {},
},
]
ops = self.generate_op_config(ops_config)

program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data1": TensorConfig(
data_gen=partial(generate_input, batch)
),
"input_data2": TensorConfig(
data_gen=partial(generate_input, batch)
),
},
outputs=["output_data"],
)

yield program_config

def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
if self.dims == 4:
self.dynamic_shape.min_input_shape = {
"input_data1": [1, 3 - 1, 3 - 1, 24 - 1],
"input_data2": [1, 3 - 1, 3 - 1, 24 - 1],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [9, 3 + 1, 3 + 1, 24 + 1],
"input_data2": [9, 3 + 1, 3 + 1, 24 + 1],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [1, 3, 3, 24],
"input_data2": [1, 3, 3, 24],
}
elif self.dims == 3:
self.dynamic_shape.min_input_shape = {
"input_data1": [1, 3 - 1, 24 - 1],
"input_data2": [1, 3 - 1, 24 - 1],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [9, 3 + 1, 24 + 1],
"input_data2": [9, 3 + 1, 24 + 1],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [1, 3, 24],
"input_data2": [1, 3, 24],
}
elif self.dims == 2:
self.dynamic_shape.min_input_shape = {
"input_data1": [1, 24],
"input_data2": [1, 24],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [9, 24],
"input_data2": [9, 24],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [1, 24],
"input_data2": [1, 24],
}

def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}

def generate_trt_nodes_num(attrs, dynamic_shape):
ver = paddle_infer.get_trt_compile_version()
trt_version = ver[0] * 1000 + ver[1] * 100 + ver[2] * 10
if trt_version < 8400:
return 0, 4
if self.dims == 4 or self.dims == 1:
return 0, 4
return 1, 3

attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
self.trt_param.max_batch_size = 9
self.trt_param.workspace_size = 1073741824

# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
program_config.set_input_type(np.float32)
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
program_config.set_input_type(np.float16)
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), 1e-3

def test(self):
self.run_test()


if __name__ == "__main__":
unittest.main()

0 comments on commit f685b4c

Please sign in to comment.