Skip to content

Commit

Permalink
update ut
Browse files Browse the repository at this point in the history
  • Loading branch information
jiweibo committed Oct 18, 2021
1 parent 8db327f commit 2dc2590
Show file tree
Hide file tree
Showing 10 changed files with 112 additions and 81 deletions.
18 changes: 18 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_inference_pass.h"
#include "paddle/fluid/inference/utils/io_utils.h"
#include "paddle/fluid/inference/utils/singleton.h"
Expand All @@ -56,6 +57,7 @@

#if PADDLE_WITH_TENSORRT
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#endif

Expand Down Expand Up @@ -1470,6 +1472,22 @@ int GetNumBytesOfDataType(DataType dtype) {

std::string GetVersion() { return paddle::get_version(); }

std::tuple<int, int, int> GetTrtCompileVersion() {
#ifdef PADDLE_WITH_TENSORRT
return paddle::inference::tensorrt::GetTrtCompileVersion();
#else
return std::tuple<int, int, int>{0, 0, 0};
#endif
}

std::tuple<int, int, int> GetTrtRuntimeVersion() {
#ifdef PADDLE_WITH_TENSORRT
return paddle::inference::tensorrt::GetTrtRuntimeVersion();
#else
return std::tuple<int, int, int>{0, 0, 0};
#endif
}

std::string UpdateDllFlag(const char *name, const char *value) {
return paddle::UpdateDllFlag(name, value);
}
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,15 @@ TEST(AnalysisPredictor, set_xpu_device_id) {
namespace paddle_infer {

TEST(Predictor, Run) {
auto trt_compile_ver = GetTrtCompileVersion();
auto trt_runtime_ver = GetTrtRuntimeVersion();
LOG(INFO) << "trt compile version: " << std::get<0>(trt_compile_ver) << "."
<< std::get<1>(trt_compile_ver) << "."
<< std::get<2>(trt_compile_ver);
LOG(INFO) << "trt runtime version: " << std::get<0>(trt_runtime_ver) << "."
<< std::get<1>(trt_runtime_ver) << "."
<< std::get<2>(trt_runtime_ver);

Config config;
config.SetModel(FLAGS_dirname);

Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/inference/api/paddle_inference_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ PD_INFER_DECL std::shared_ptr<Predictor> CreatePredictor(
PD_INFER_DECL int GetNumBytesOfDataType(DataType dtype);

PD_INFER_DECL std::string GetVersion();
PD_INFER_DECL std::tuple<int, int, int> GetTrtCompileVersion();
PD_INFER_DECL std::tuple<int, int, int> GetTrtRuntimeVersion();
PD_INFER_DECL std::string UpdateDllFlag(const char* name, const char* value);

namespace services {
Expand Down
16 changes: 16 additions & 0 deletions paddle/fluid/inference/tensorrt/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,24 @@ static nvinfer1::IPluginRegistry* GetPluginRegistry() {
static int GetInferLibVersion() {
return static_cast<int>(dy::getInferLibVersion());
}
#else
static int GetInferLibVersion() { return 0; }
#endif

static std::tuple<int, int, int> GetTrtRuntimeVersion() {
int ver = GetInferLibVersion();
int major = ver / 1000;
ver -= major * 1000;
int minor = ver / 100;
int patch = ver - minor * 100;
return std::tuple<int, int, int>{major, minor, patch};
}

static std::tuple<int, int, int> GetTrtCompileVersion() {
return std::tuple<int, int, int>{NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR,
NV_TENSORRT_PATCH};
}

// A logger for create TensorRT infer builder.
class NaiveLogger : public nvinfer1::ILogger {
public:
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ void DynamicShapeTest(bool allow_build_at_runtime) {
engine_op_desc.SetAttr("use_static_engine", true);
engine_op_desc.SetAttr("dynamic_shape_names", std::vector<std::string>{"x"});
engine_op_desc.SetAttr("dynamic_shape_lens", std::vector<int>{4});
engine_op_desc.SetAttr("min_input_shape", std::vector<int>{0, 0, 0, 0});
engine_op_desc.SetAttr("min_input_shape", std::vector<int>{1, 4, 1, 1});
engine_op_desc.SetAttr("max_input_shape", std::vector<int>{2, 4, 1, 1});
engine_op_desc.SetAttr("opt_input_shape", std::vector<int>{2, 4, 1, 1});

Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/inference_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,8 @@ void BindInferenceApi(py::module *m) {
m->def("paddle_dtype_size", &paddle::PaddleDtypeSize);
m->def("paddle_tensor_to_bytes", &SerializePDTensorToBytes);
m->def("get_version", &paddle_infer::GetVersion);
m->def("get_trt_compile_version", &paddle_infer::GetTrtCompileVersion);
m->def("get_trt_runtime_version", &paddle_infer::GetTrtRuntimeVersion);
m->def("get_num_bytes_of_data_type", &paddle_infer::GetNumBytesOfDataType);
}

Expand Down
2 changes: 1 addition & 1 deletion python/paddle/fluid/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@

from .wrapper import Config, DataType, PlaceType, PrecisionType, Tensor, Predictor

from ..core import create_predictor, get_version, get_num_bytes_of_data_type, PredictorPool
from ..core import create_predictor, get_version, get_num_bytes_of_data_type, PredictorPool, get_trt_compile_version, get_trt_runtime_version
59 changes: 59 additions & 0 deletions python/paddle/fluid/tests/unittests/test_inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@

import os, shutil
import unittest
import paddle
paddle.enable_static()
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.core import PaddleTensor
from paddle.fluid.core import PaddleDType
from paddle.inference import Config, Predictor, create_predictor
from paddle.inference import get_trt_compile_version, get_trt_runtime_version


class TestInferenceApi(unittest.TestCase):
Expand Down Expand Up @@ -54,5 +58,60 @@ def test_inference_api(self):
tensor_float.ravel().tolist())


def get_sample_model():
place = fluid.CPUPlace()
exe = fluid.Executor(place)

main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
data = fluid.data(name="data", shape=[-1, 6, 64, 64], dtype="float32")
conv_out = fluid.layers.conv2d(
input=data,
num_filters=3,
filter_size=3,
groups=1,
padding=0,
bias_attr=False,
act=None)
exe.run(startup_program)
serialized_program = paddle.static.serialize_program(
data, conv_out, program=main_program)
serialized_params = paddle.static.serialize_persistables(
data, conv_out, executor=exe, program=main_program)
return serialized_program, serialized_params


class TestInferenceBaseAPI(unittest.TestCase):
def get_config(self, model, params):
config = Config()
config.set_model_buffer(model, len(model), params, len(params))
config.enable_use_gpu(100, 0)
return config

def test_apis(self):
print('trt compile version:', get_trt_compile_version())
print('trt runtime version:', get_trt_runtime_version())
program, params = get_sample_model()
config = self.get_config(program, params)
predictor = create_predictor(config)
in_names = predictor.get_input_names()
in_handle = predictor.get_input_handle(in_names[0])
in_data = np.ones((1, 6, 32, 32)).astype(np.float32)
in_handle.copy_from_cpu(in_data)
predictor.run()

def test_wrong_input(self):
with self.assertRaises(TypeError):
program, params = get_sample_model()
config = self.get_config(program, params)
predictor = create_predictor(config)
in_names = predictor.get_input_names()
in_handle = predictor.get_input_handle(in_names[0])
in_data = np.ones((1, 6, 64, 64)).astype(np.float32)
in_handle.copy_from_cpu(list(in_data))
predictor.run()


if __name__ == '__main__':
unittest.main()
4 changes: 4 additions & 0 deletions python/paddle/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from ..fluid.inference import Predictor # noqa: F401
from ..fluid.inference import create_predictor # noqa: F401
from ..fluid.inference import get_version # noqa: F401
from ..fluid.inference import get_trt_compile_version # noqa: F401
from ..fluid.inference import get_trt_runtime_version # noqa: F401
from ..fluid.inference import get_num_bytes_of_data_type # noqa: F401
from ..fluid.inference import PredictorPool # noqa: F401

Expand All @@ -32,6 +34,8 @@
'Predictor',
'create_predictor',
'get_version',
'get_trt_compile_version',
'get_trt_runtime_version',
'get_num_bytes_of_data_type',
'PredictorPool'
]
79 changes: 0 additions & 79 deletions python/paddle/tests/test_inference_api.py

This file was deleted.

1 comment on commit 2dc2590

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.