
How to use FP16 or INT8? #32

Closed
ChengshuLi opened this issue Jul 18, 2018 · 22 comments
Labels: question (Further information is requested), triaged (Issue has been triaged by maintainers)

Comments


ChengshuLi commented Jul 18, 2018

Hi,

I was trying to use FP16 and INT8.

I understand this is how you prepare an FP32 model.

model = onnx.load("/path/to/model.onnx")
engine = backend.prepare(model, device='CUDA:1')
input_data = np.random.random(size=(32, 3, 224, 224)).astype(np.float32)

I tried this, but it didn't work.

model = onnx.load("/path/to/model.onnx")
engine = backend.prepare(model, device='CUDA:1', dtype=np.float16)
input_data = np.random.random(size=(32, 3, 224, 224)).astype(np.float16)

Any help will be greatly appreciated. Thanks!

@yinghai


yinghai commented Jul 18, 2018

I don't think it's fully supported right now. What's the error message?


ChengshuLi commented Jul 19, 2018

Here is my output when I ran:
onnx2trt my_model.onnx -o my_engine.trt -d 16

----------------------------------------------------------------
Input filename:   drn.onnx
ONNX IR version:  0.0.3
Opset version:    6
Producer name:    pytorch
Producer version: 0.3
Domain:           
Model version:    0
Doc string:       
----------------------------------------------------------------
Parsing model
Building TensorRT engine, FP16 available:0
    Max batch size:     32
    Max workspace size: 1024 MiB

The output model is roughly the same as the Float32 one: no speed gain, and the model size stays the same.

After further investigation, I noticed that

bool fp16 = trt_builder->platformHasFastFp16();
is false for me. I think that's why it did not give me an FP16 version of the model.

I am pretty new to TensorRT so this might be a stupid question. But does this depend on my GPU? I am currently using Titan X (Pascal).

Thanks a lot @yinghai


ChengshuLi commented Jul 19, 2018

I tried a Titan V, which I believe supports FP16. Notice that FP16 available changed from 0 to 1. However, it gave me an error when I ran onnx2trt my_model.onnx -o my_engine.trt -d 16:

----------------------------------------------------------------
Input filename:   drn.onnx
ONNX IR version:  0.0.3
Opset version:    6
Producer name:    pytorch
Producer version: 0.3
Domain:           
Model version:    0
Doc string:       
----------------------------------------------------------------
Parsing model
Building TensorRT engine, FP16 available:1
    Max batch size:     32
    Max workspace size: 1024 MiB
[2018-07-19 02:45:33   ERROR] reformat.cu (1369) - Cuda Error in NCHWToNHWC: 9
[2018-07-19 02:45:33   ERROR] reformat.cu (1369) - Cuda Error in NCHWToNHWC: 9
terminate called after throwing an instance of 'std::runtime_error'
  what():  Failed to create object
fish: “onnx2trt drn.onnx -o drn_float1…” terminated by signal SIGABRT (Abort)

FYI, onnx2trt my_model.onnx -o my_engine.trt using Float32 works fine.


yinghai commented Jul 19, 2018

Is it possible to share a minimal reproducible model with us? It doesn't need to be the real model, and the weights can be randomized, as long as it helps us reproduce the issue.


ChengshuLi commented Jul 19, 2018

@yinghai Thanks a lot for helping out.

I will try to walk you through what I have done:

The original semantic segmentation PyTorch model is here
The inference time for this PyTorch model is ~0.17s / image on a Titan X. I tried to use TensorRT to accelerate it.

First of all, I converted the PyTorch model to an ONNX model using the code below, because I didn't want to create the model definition in the TensorRT format from scratch.

import torch
from segment import DRNSeg

weights = 'drn-d-105_ms_cityscapes.pth'
model = DRNSeg('drn_d_105', 19, pretrained_model=None, pretrained=False)
model.load_state_dict(torch.load(weights))
dummy_input = torch.randn(1, 3, 512, 1024, requires_grad=False)
torch.onnx.export(model, dummy_input, "drn.onnx")

The converted ONNX model is here.

Then I ran onnx2trt drn.onnx -o drn.trt successfully and got drn.trt. I followed NVIDIA's TensorRT tutorial to load the engine (i.e. drn.trt file) and to run inference on a sample image.

The segmentation result looks correct, which is why I believe the entire conversion process is correct (PyTorch -> ONNX -> TensorRT engine trt file). However, the running speed is slower than before at ~0.33s / image on Titan X. It ran even slower at ~0.42s / image on Titan V. I don't quite get why TensorRT would slow things down. Do you have any ideas?

Therefore, I wanted to try FP16 to further speed up the model. On Titan X, the conversion onnx2trt went through but the model was still using FP32 because

bool fp16 = trt_builder->platformHasFastFp16();
is false.

On Titan V, the conversion onnx2trt failed and here is the error message again.

----------------------------------------------------------------
Input filename:   drn.onnx
ONNX IR version:  0.0.3
Opset version:    6
Producer name:    pytorch
Producer version: 0.3
Domain:           
Model version:    0
Doc string:       
----------------------------------------------------------------
Parsing model
Building TensorRT engine, FP16 available:1
    Max batch size:     32
    Max workspace size: 1024 MiB
[2018-07-19 02:45:33   ERROR] reformat.cu (1369) - Cuda Error in NCHWToNHWC: 9
[2018-07-19 02:45:33   ERROR] reformat.cu (1369) - Cuda Error in NCHWToNHWC: 9
terminate called after throwing an instance of 'std::runtime_error'
  what():  Failed to create object
fish: “onnx2trt drn.onnx -o drn_float1…” terminated by signal SIGABRT (Abort)

Sorry for the long message. I hope this clarifies my situation. Thanks again for your generous help.


yinghai commented Jul 20, 2018

However, the running speed is slower than before at ~0.33s / image on Titan X. It ran even slower at ~0.42s / image on Titan V. I don't quite get why TensorRT would slow things down. Do you have any ideas?

This sounds strange. How did you run the TRT engine?


ChengshuLi commented Jul 20, 2018

Here is my code, modified from the TensorRT example code. Sorry it's a bit long. Please let me know if you see any problems with it.

import sys

try:
    from PIL import Image
except ImportError as err:
    raise ImportError("""ERROR: Failed to import module ({})
Please make sure you have Pillow installed.
For installation instructions, see:
http://pillow.readthedocs.io/en/stable/installation.html""".format(err))

try:
    from tensorrt.parsers import onnxparser
    import tensorrt as trt
    import pycuda.driver as cuda
    import pycuda.autoinit
    import argparse
    import numpy as np
except ImportError as err:
    raise ImportError("""ERROR: Failed to import module ({})
Please make sure you have pycuda and the example dependencies installed.
https://wiki.tiker.net/PyCuda/Installation/Linux
pip(3) install tensorrt[examples]""".format(err))

# Logger
G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.INFO)

def get_input_output_names(trt_engine):
    nbindings = trt_engine.get_nb_bindings();
    maps = {}
    for b in range(0, nbindings):
        dims = trt_engine.get_binding_dimensions(b).to_DimsCHW()
        if (trt_engine.binding_is_input(b)):
            print("Found input: ")
            print(trt_engine.get_binding_name(b))
            print("shape=" + str(dims.C()) + " , " + str(dims.H()) + " , " + str(dims.W()))
            print("dtype=" + str(trt_engine.get_binding_data_type(b)))
            maps["input"] = trt_engine.get_binding_name(b)
        else:
            print("Found output: ")
            print(trt_engine.get_binding_name(b))
            print("shape=" + str(dims.C()) + " , " + str(dims.H()) + " , " + str(dims.W()))
            print("dtype=" + str(trt_engine.get_binding_data_type(b)))
            output_name = "output1" if dims.H() == 512 and dims.W() == 1024 else "output2"
            maps[output_name] = trt_engine.get_binding_name(b)
    return maps

def normalize_data(data, inp_dims):
    in_size = inp_dims.C() * inp_dims.H() * inp_dims.W()
    for s in range(0, in_size):
        data[s] = data[s] / 255
    return data

def read_ascii_file(input_file, size):
    ret = []
    for line in open(input_file, 'r'):
        ret += line.split()

    ret = np.array(ret, np.float32)
    assert(ret.size == size)
    return ret

def prepare_input(input_file, trt_engine, file_format):
    in_out = get_input_output_names(trt_engine)
    input_indx = trt_engine.get_binding_index(in_out["input"])
    inp_dims = trt_engine.get_binding_dimensions(input_indx).to_DimsCHW()
    if (file_format=="ascii"):
        img = read_ascii_file(input_file, inp_dims.C() * inp_dims.H() * inp_dims.W())
    elif (file_format == "ppm"):
        img = preprocess_image(input_file, inp_dims)
    else:
        print("Not supported format")
        sys.exit()
    return img

def process_output(output):
    output = output.argmax(axis=0)
    output = output.astype(np.int8)
    output = Image.fromarray(output)
    output.save("result.png")

def preprocess_image(image_path, inp_dims):
    ppm_image = Image.open(image_path)
    # resize image
    new_h = inp_dims.H()
    new_w = inp_dims.W()
    size = (new_w, new_h)
    # resize image
    img = ppm_image.resize(size, Image.BILINEAR)
    # convert to numpy array
    img = np.array(img)
    # hwc2chw
    img = img.transpose(2, 0, 1)
    # convert image to 1D array
    img = img.ravel()
    # convert image to float
    img = img.astype(np.float32)
    # normalize image data
    img = normalize_data(img, inp_dims)
    return img

def inference_image(context, input_img, batch_size):
    # load engine
    trt_engine = context.get_engine()

    in_out = get_input_output_names(trt_engine)
    input_indx = trt_engine.get_binding_index(in_out["input"])
    inp_dims = trt_engine.get_binding_dimensions(input_indx).to_DimsCHW()
    output1_indx = trt_engine.get_binding_index(in_out["output1"])
    out1_dims = trt_engine.get_binding_dimensions(output1_indx).to_DimsCHW()
    output2_indx = trt_engine.get_binding_index(in_out["output2"])
    out2_dims = trt_engine.get_binding_dimensions(output2_indx).to_DimsCHW()

    print("input dims:", inp_dims.C(), inp_dims.H(), inp_dims.W())
    print("output1 dims:", out1_dims.C(), out1_dims.H(), out1_dims.W())
    print("output2 dims:", out2_dims.C(), out2_dims.H(), out2_dims.W())

    # create output array
    output1 = np.empty((out1_dims.C(), out1_dims.H(), out1_dims.W()), dtype=np.float32)
    output2 = np.empty((out2_dims.C(), out2_dims.H(), out2_dims.W()), dtype=np.float32)
    # allocate device memory
    d_input = cuda.mem_alloc(batch_size * input_img.size * input_img.dtype.itemsize)
    d_output1 = cuda.mem_alloc(batch_size * output1.size * output1.dtype.itemsize)
    d_output2 = cuda.mem_alloc(batch_size * output2.size * output2.dtype.itemsize)
    # create input/output bindings
    bindings = [int(d_input), int(d_output1), int(d_output2)]
    stream = cuda.Stream()
    # transfer input data to device
    cuda.memcpy_htod_async(d_input, input_img, stream)
    # execute model
    import time
    for i in range(100):
        start = time.time()
        context.enqueue(batch_size, bindings, stream.handle, None)
        print(time.time() - start)
    # transfer predictions
    cuda.memcpy_dtoh_async(output1, d_output1, stream)
    # synchronize threads
    stream.synchronize()
    return output1

def sample_onnx_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--image_file", type=str, required=True, help="Path to the image file")
    parser.add_argument("-m", "--model_file", type=str, required=True, help="ONNX Model file")
    parser.add_argument("-d", "--data_type", default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 8, 16, 32 bit")
    parser.add_argument("-b", "--max_batch_size", default=32, type=int, help="Maximum batch size")
    parser.add_argument("-w", "--max_workspace_size", default=1024*1024, type=int, help="Maximum workspace size")
    parser.add_argument("-v", "--add_verbosity", action="store_true")
    parser.add_argument("-q", "--reduce_verbosity", action="store_true")
    parser.add_argument("-l", "--print_layer_info", action="store_true")
    args = parser.parse_args()

    image_file = str.strip(args.image_file)
    model_file = str.strip(args.model_file)
    max_batch_size = args.max_batch_size
    max_workspace_size = args.max_workspace_size
    data_type = args.data_type
    add_verbosity = args.add_verbosity
    reduce_verbosity = args.reduce_verbosity
    print_layer_info = args.print_layer_info

    print("Input Arguments: ")
    print("model_file", model_file)
    print("data_type", data_type)
    print("max_workspace_size", max_workspace_size)
    print("max_batch_size", max_batch_size)
    print("add_verbosity", add_verbosity)
    print("reduce_verbosity", reduce_verbosity)
    print("print_layer_info", print_layer_info)

    # set batch size
    batch_size = 1

    # load engine directly from model.trt file
    trt_engine = trt.utils.load_engine(G_LOGGER, model_file)

    # create input vector
    file_format = 'ppm'
    input_img = prepare_input(image_file, trt_engine, file_format)
    print(input_img.shape)

    if input_img.size == 0:
        msg = "sampleONNX the input tensor is of zero size - please check your path to the input or the file type"
        G_LOGGER.log(trt.infer.Logger.Severity_kERROR, msg)

    trt_context = trt_engine.create_execution_context()
    output = inference_image(trt_context, input_img, batch_size)

    # post processing stage
    process_output(output)

    # clean up
    trt_context.destroy()
    trt_engine.destroy()
    print("&&&& PASSED Onnx Parser Tested Successfully")

if __name__=="__main__":
    sample_onnx_parser()

Specifically I measure the speed by doing:

    for i in range(100):
        start = time.time()
        context.enqueue(batch_size, bindings, stream.handle, None)
        print(time.time() - start)

Is it correct?

You can run the code by doing:

python run_inference.py -m drn.trt -i sample.png


yinghai commented Jul 24, 2018

The time measurement doesn't seem to be correct, as enqueue just submits a task to the CUDA stream. Is your image size always the same? If so, you can avoid allocating memory for each inference.
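
For reference, a minimal sketch of that reuse, keeping the same TensorRT 3.x-era pycuda API and names as the script above; frames is a hypothetical iterable of preprocessed images, and context, output1, output2, input_img and batch_size are the objects already built in inference_image:

import pycuda.driver as cuda
import pycuda.autoinit  # noqa: F401 -- initializes the CUDA context

# Allocate device buffers and the stream once, outside the per-image loop.
d_input = cuda.mem_alloc(batch_size * input_img.size * input_img.dtype.itemsize)
d_output1 = cuda.mem_alloc(batch_size * output1.size * output1.dtype.itemsize)
d_output2 = cuda.mem_alloc(batch_size * output2.size * output2.dtype.itemsize)
bindings = [int(d_input), int(d_output1), int(d_output2)]
stream = cuda.Stream()

for frame in frames:  # hypothetical iterable of preprocessed float32 images
    cuda.memcpy_htod_async(d_input, frame, stream)
    context.enqueue(batch_size, bindings, stream.handle, None)
    cuda.memcpy_dtoh_async(output1, d_output1, stream)
    stream.synchronize()  # results (and timings) are only valid after this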

@ChengshuLi

@yinghai Thanks for your info.

Yes, the image size is always the same. What do you think is the best way to measure inference time? Any suggestions?

Also, any idea why the FP16 conversion fails on the Titan V?

Thanks!


yinghai commented Jul 24, 2018

You should measure it after stream.synchronize().
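
For reference, a minimal sketch of that measurement, reusing the names from the script above; the warm-up loop is an addition for illustration, not part of the original code:

import time

# Warm-up: exclude one-time CUDA/cuDNN initialization from the timing.
for _ in range(10):
    context.enqueue(batch_size, bindings, stream.handle, None)
stream.synchronize()

runs = 100
start = time.time()
for _ in range(runs):
    context.enqueue(batch_size, bindings, stream.handle, None)
stream.synchronize()  # wait until the GPU has actually finished all runs
print("mean latency: %.4f s / image" % ((time.time() - start) / runs))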


ChengshuLi commented Jul 25, 2018

@yinghai

Thanks! Unfortunately the result is the same. The inference time for the FP32 TensorRT engine is still ~0.42s / image on Titan V, whereas the original PyTorch model runs at ~0.17s / image.

I am still unable to convert the model to FP16.

One possibility I can think of is that cuda.memcpy_htod_async(d_input, input_img, stream) takes too much time. Since my input image is large (512 * 1024), I am not sure how long the memcpy takes, and I don't know how to measure it since this function appears to be asynchronous, as its name indicates. Any suggestions for how to benchmark this?

UPDATE: I commented out cuda.memcpy_htod_async and cuda.memcpy_dtoh_async. The speed remains the same.
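
For what it's worth, one way to time an individual asynchronous call is with CUDA events recorded on the same stream. A minimal sketch using the pycuda event API, assuming d_input, input_img and stream from the script above:

import pycuda.driver as cuda

start_evt, end_evt = cuda.Event(), cuda.Event()
start_evt.record(stream)
cuda.memcpy_htod_async(d_input, input_img, stream)
end_evt.record(stream)
end_evt.synchronize()  # block until everything recorded before end_evt has finished
print("H2D copy took %.3f ms" % start_evt.time_till(end_evt))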


Faldict commented Jul 26, 2018

@ChengshuLi I'm stuck on the same problem, although I used MXNet models. Here's my code to measure TensorRT's inference time.

ts = time.time()
stream = cuda.Stream()
cuda.memcpy_htod_async(d_input, data, stream)
context.enqueue(batch_size, bindings, stream.handle, None)
cuda.memcpy_dtoh_async(output, d_output, stream)
stream.synchronize()
te = time.time()

And it seems that TensorRT actually slowed the inference down.


yinghai commented Jul 26, 2018

Looks like we can try something like this https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/c_api/classnvinfer1_1_1_i_profiler.html to profile the TensorRT engine. Does it have a Python binding?


Faldict commented Jul 26, 2018

@yinghai I guess you may refer to this.
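
For reference, later TensorRT releases expose the profiler in Python as trt.IProfiler. A rough sketch, assuming a modern TensorRT Python API (not the 3.x-era one used earlier in this thread) and an existing execution context plus a bindings list of device pointers:

import tensorrt as trt

class LayerProfiler(trt.IProfiler):
    # Collects the per-layer execution times that TensorRT reports.
    def __init__(self):
        trt.IProfiler.__init__(self)
        self.records = []

    def report_layer_time(self, layer_name, ms):
        self.records.append((layer_name, ms))

profiler = LayerProfiler()
context.profiler = profiler    # layer times are only reported for synchronous execution
context.execute_v2(bindings)   # execute_v2 is available in TensorRT 7+; bindings = device pointers
for name, ms in sorted(profiler.records, key=lambda r: -r[1])[:10]:
    print("%-60s %.3f ms" % (name, ms))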


poorneshwaran commented May 13, 2019

[Quoted @ChengshuLi's walkthrough and Titan V error log from above.]

I'm trying to use your method, and I got an error like

/onnx/onnx/onnx_onnx2trt_onnx.pb.h:12:2: error: #error This file was generated by a newer version of protoc which is

while running make -j8. I think this is a protoc version error. May I know what version you used? Please also share the protoc installation steps.

@poorneshwaran

[Quoted the previous comment.]

I solved the installation problem.

Now I followed your code, and I got a segmentation fault (core dumped) error on TensorRT 5.1.2.

Can you please help me with that?

@liuchang8am

Any updates?
Same issue here: no speed gain from PyTorch -> ONNX -> TRT.

@poorneshwaran

[Quoted the previous comment.]

Generally, segmentation fault (core dumped) relates to a memory allocation problem. I have full access to my paths and files, and there is free memory, yet it still reproduces the same error. Please post an update if anyone has solved this or has ideas related to this problem.

@ShawnNew

It seems that bool fp16 = trt_builder->platformHasFastFp16(); always ends up false, which indicates that the platform has no FP16 support.
Does anyone know how to add FP16 functionality to the platform?

Many thanks.

@oelgendy

FP16 inference is 10x slower than FP32!
Hi,
I am doing inference with ONNX Runtime in C++. I converted the ONNX file to FP16 in Python using onnxmltools' convert_float_to_float16. I obtain the FP16 tensor from a libtorch tensor and wrap it in an ONNX FP16 tensor using
g_ort->CreateTensorWithDataAsOrtValue(memory_info, libtorchTensor.data_ptr(), input_tensor_size * 2, input_node_dims.data(), input_node_dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, &onnxTensor)
What am I missing?
Thanks,
-Omar

kevinch-nv added the question and triaged labels on Oct 25, 2020
@kevinch-nv

Coming late to this thread, so I'll try my best to answer the questions posed by multiple users:

  • For later versions of TensorRT, we recommend using the trtexec tool to convert ONNX models to TRT engines rather than onnx2trt (we're planning on deprecating onnx2trt soon)
  • To use mixed precision with TensorRT, you'll have to specify the corresponding --fp16 or --int8 flags for trtexec to build an engine in your specified precision (see the example command after this list)
  • If trt_builder->platformHasFastFp16() returns false, that means the GPU on your system does not support FP16 operations. You can refer to this list for the cards that support different precisions
  • ONNX-Runtime inference is different from TensorRT inference. Please re-run your benchmarks with trtexec with the --fp16 flag specified.
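
For example, assuming the drn.onnx model from this thread, an FP16 engine could be built with something along these lines (the output file name is only illustrative):

trtexec --onnx=drn.onnx --saveEngine=drn_fp16.trt --fp16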

Closing this thread; if anyone needs an update on their specific issue, feel free to open a new issue.


Lenan22 commented Oct 31, 2022

Please refer to our open-source quantization tool, ppq; we can help you solve quantization problems:
https://github.com/openppl-public/ppq/blob/master/md_doc/deploy_trt_by_OnnxParser.md
