
TensorRT 10.3 wrong results! #4330

Open
OctaAIVision opened this issue Jan 19, 2025 · 8 comments
Assignees: kevinch-nv
Labels: Module:Documentation, triaged, waiting for feedback

Comments

@OctaAIVision

Description

I’m in the process of migrating from TensorRT 8.6 to 10.3. Following the migration guide provided in the documentation, I was able to get inference working on 10.3. However, I’m seeing a significant accuracy regression compared to 8.6 (I get wrong results, not slower inference), particularly when dealing with changes in dynamic input shapes.
I am currently working on a two-stage pipeline where the output of the first network serves as the input to the second network. The two networks are connected in sequence.
Has anyone encountered similar issues, or could anyone provide guidance on how to handle memory management in TensorRT 10.3 when using dynamic input shapes?

Any help would be greatly appreciated!

Environment

TensorRT Version: 10.3.0.30

NVIDIA GPU: NVIDIA Jetson Orin NX (16 GB RAM), aarch64

NVIDIA Driver Version: JetPack 6.1

CUDA Version: 12.6.68

CUDNN Version: 9.3.0.75

Operating System: Ubuntu 22.04

Python Version (if applicable): 3.10.12

Relevant Files

common_trt_10.py:

import os
import numpy as np
import tensorrt as trt
from cuda import cuda, cudart
import ctypes
from typing import List, Tuple


def check_cuda_err(err):
    if isinstance(err, cuda.CUresult):
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("Cuda Error: {}".format(err))
    elif isinstance(err, cudart.cudaError_t):
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError("Cuda Runtime Error: {}".format(err))
    else:
        # elif above keeps a successful CUresult from falling into this branch
        raise RuntimeError("Unknown error type: {}".format(err))


def cuda_call(call):
    err, res = call[0], call[1:]
    check_cuda_err(err)
    if len(res) == 1:
        res = res[0]
    return res


class HostDeviceMem:
    def __init__(self, size: int, dtype: np.dtype):
        nbytes = size * dtype.itemsize
        host_mem = cuda_call(cudart.cudaMallocHost(nbytes))
        pointer_type = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))

        self._host = np.ctypeslib.as_array(ctypes.cast(host_mem, pointer_type), (size,))
        self._device = cuda_call(cudart.cudaMalloc(nbytes))
        self._nbytes = nbytes

    @property
    def host(self) -> np.ndarray:
        return self._host

    @host.setter
    def host(self, arr: np.ndarray):
        if arr.size > self.host.size:
            raise ValueError(
                f"Tried to fit an array of size {arr.size} into host memory of size {self.host.size}"
            )
        np.copyto(self.host[: arr.size], arr.flat)

    @property
    def device(self) -> int:
        return self._device

    @property
    def nbytes(self) -> int:
        return self._nbytes

    def __str__(self):
        return f"Host:\n{self.host}\nDevice:\n{self.device}\nSize:\n{self.nbytes}\n"

    def __repr__(self):
        return self.__str__()

    def free(self):
        cuda_call(cudart.cudaFree(self.device))
        cuda_call(cudart.cudaFreeHost(self.host.ctypes.data))


def allocate_buffers(engine: trt.ICudaEngine, inputs_shape: List[Tuple[int]]):
    inputs = []
    outputs = []
    bindings = []
    stream = cuda_call(cudart.cudaStreamCreate())
    tensor_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
    for shape, binding in zip(inputs_shape, tensor_names):
        size = trt.volume(shape)
        dtype = np.dtype(trt.nptype(engine.get_tensor_dtype(binding)))

        bindingMemory = HostDeviceMem(size, dtype)
        bindings.append(int(bindingMemory.device))
        

        if engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT:
            inputs.append(bindingMemory)
        else:
            outputs.append(bindingMemory)

    return inputs, outputs, bindings, stream


def free_buffers(
    inputs: List[HostDeviceMem],
    outputs: List[HostDeviceMem],
    stream: cudart.cudaStream_t,
):
    for mem in inputs + outputs:
        mem.free()
    cuda_call(cudart.cudaStreamDestroy(stream))


def memcpy_host_to_device(device_ptr: int, host_arr: np.ndarray):
    nbytes = host_arr.size * host_arr.itemsize
    cuda_call(
        cudart.cudaMemcpy(
            device_ptr, host_arr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
        )
    )


def memcpy_device_to_host(host_arr: np.ndarray, device_ptr: int):
    nbytes = host_arr.size * host_arr.itemsize
    cuda_call(
        cudart.cudaMemcpy(
            host_arr, device_ptr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
        )
    )


def _do_inference_base(inputs, outputs, stream, execute_async):
    kind = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
    [
        cuda_call(
            cudart.cudaMemcpyAsync(inp.device, inp.host, inp.nbytes, kind, stream)
        )
        for inp in inputs
    ]
    execute_async()
    kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
    [
        cuda_call(
            cudart.cudaMemcpyAsync(out.host, out.device, out.nbytes, kind, stream)
        )
        for out in outputs
    ]
    cuda_call(cudart.cudaStreamSynchronize(stream))
    return [out.host for out in outputs]


def do_inference_v3(context, engine, bindings, inputs, outputs, stream):
    def execute_async():
        # Set tensor addresses before executing
        for i in range(engine.num_io_tensors):
            context.set_tensor_address(engine.get_tensor_name(i), bindings[i])
        context.execute_async_v3(stream_handle=stream)

    return _do_inference_base(inputs, outputs, stream, execute_async)
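
Note on the dynamic-shape question: allocate_buffers above sizes every buffer from the caller-supplied inputs_shape list, so if the shape assumed for an output differs from what the engine actually produces for a given input, the host copies read the wrong number of bytes and the results look corrupted. A minimal sketch of an alternative (illustrative only, reusing the helpers defined above, not part of the posted file) that queries the execution context for the resolved shapes after the input shapes have been set:

def allocate_buffers_from_context(engine: trt.ICudaEngine, context: trt.IExecutionContext):
    # Assumes every input shape has already been set on `context` via set_input_shape().
    inputs, outputs, bindings = [], [], []
    stream = cuda_call(cudart.cudaStreamCreate())
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        # The context reports the shape resolved for this inference; a dim of -1
        # here means the output is data-dependent and needs an IOutputAllocator
        # instead of a pre-sized buffer.
        shape = context.get_tensor_shape(name)
        size = trt.volume(shape)
        dtype = np.dtype(trt.nptype(engine.get_tensor_dtype(name)))
        mem = HostDeviceMem(size, dtype)
        bindings.append(int(mem.device))
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            inputs.append(mem)
        else:
            outputs.append(mem)
    return inputs, outputs, bindings, stream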

common_trt_8.py:

import os
import numpy as np
import tensorrt as trt
from cuda import cuda, cudart
import ctypes
from typing import List, Tuple


def check_cuda_err(err):
    if isinstance(err, cuda.CUresult):
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("Cuda Error: {}".format(err))
    elif isinstance(err, cudart.cudaError_t):
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError("Cuda Runtime Error: {}".format(err))
    else:
        # elif above keeps a successful CUresult from falling into this branch
        raise RuntimeError("Unknown error type: {}".format(err))


def cuda_call(call):
    err, res = call[0], call[1:]
    check_cuda_err(err)
    if len(res) == 1:
        res = res[0]
    return res


class HostDeviceMem:
    def __init__(self, size: int, dtype: np.dtype):
        nbytes = size * dtype.itemsize
        host_mem = cuda_call(cudart.cudaMallocHost(nbytes))
        pointer_type = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))

        self._host = np.ctypeslib.as_array(ctypes.cast(host_mem, pointer_type), (size,))
        self._device = cuda_call(cudart.cudaMalloc(nbytes))
        self._nbytes = nbytes

    @property
    def host(self) -> np.ndarray:
        return self._host

    @host.setter
    def host(self, arr: np.ndarray):
        if arr.size > self.host.size:
            raise ValueError(
                f"Tried to fit an array of size {arr.size} into host memory of size {self.host.size}"
            )
        np.copyto(self.host[: arr.size], arr.flat)

    @property
    def device(self) -> int:
        return self._device

    @property
    def nbytes(self) -> int:
        return self._nbytes

    def __str__(self):
        return f"Host:\n{self.host}\nDevice:\n{self.device}\nSize:\n{self.nbytes}\n"

    def __repr__(self):
        return self.__str__()

    def free(self):
        cuda_call(cudart.cudaFree(self.device))
        cuda_call(cudart.cudaFreeHost(self.host.ctypes.data))


def allocate_buffers(engine: trt.ICudaEngine, inputs_shape: List[Tuple[int]]):
    inputs = []
    outputs = []
    bindings = []
    stream = cuda_call(cudart.cudaStreamCreate())
    tensor_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
    for shape, binding in zip(inputs_shape, tensor_names):
        size = trt.volume(shape)
        dtype = np.dtype(trt.nptype(engine.get_tensor_dtype(binding)))

        bindingMemory = HostDeviceMem(size, dtype)
        bindings.append(int(bindingMemory.device))

        if engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT:
            inputs.append(bindingMemory)
        else:
            outputs.append(bindingMemory)

    return inputs, outputs, bindings, stream


def free_buffers(
    inputs: List[HostDeviceMem],
    outputs: List[HostDeviceMem],
    stream: cudart.cudaStream_t,
):
    for mem in inputs + outputs:
        mem.free()
    cuda_call(cudart.cudaStreamDestroy(stream))


def memcpy_host_to_device(device_ptr: int, host_arr: np.ndarray):
    nbytes = host_arr.size * host_arr.itemsize
    cuda_call(
        cudart.cudaMemcpy(
            device_ptr, host_arr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
        )
    )


def memcpy_device_to_host(host_arr: np.ndarray, device_ptr: int):
    nbytes = host_arr.size * host_arr.itemsize
    cuda_call(
        cudart.cudaMemcpy(
            host_arr, device_ptr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
        )
    )


def _do_inference_base(inputs, outputs, stream, execute_async):
    kind = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
    [
        cuda_call(
            cudart.cudaMemcpyAsync(inp.device, inp.host, inp.nbytes, kind, stream)
        )
        for inp in inputs
    ]
    execute_async()
    kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
    [
        cuda_call(
            cudart.cudaMemcpyAsync(out.host, out.device, out.nbytes, kind, stream)
        )
        for out in outputs
    ]
    cuda_call(cudart.cudaStreamSynchronize(stream))
    return [out.host for out in outputs]


def do_inference_v2(context, bindings, inputs, outputs, stream):
    def execute_async():
        context.execute_async_v2(bindings=bindings, stream_handle=stream)

    return _do_inference_base(inputs, outputs, stream, execute_async)

First model inference:

        # Allocate buffers based on input shape
        input_name = self.engine.get_tensor_name(0)
        self.context.set_input_shape(input_name, input0.shape)

        model_shapes = [input0.shape, output0.shape, output1.shape]
        self.inputs, self.outputs, self.bindings, self.stream = common.allocate_buffers(self.engine, model_shapes)

        # Run inference on input
        np.copyto(self.inputs[0].host, input0.ravel())
        out = common.do_inference_v3(self.context, self.engine, self.bindings, self.inputs, self.outputs, self.stream)
        out0 = torch.from_numpy(out[0].reshape(output0.shape).copy())
        out1 = torch.from_numpy(out[1].reshape(output1.shape).copy())
        #out0 = torch.from_numpy(self.outputs[0].host.reshape(output0.shape).copy())
        #out1 = torch.from_numpy(self.outputs[1].host.reshape(output1.shape).copy())
        common.free_buffers(self.inputs, self.outputs, self.stream)

Second model inference:

        model_shapes = [
            input0.shape,
            input1.shape,
            input2.shape,
            input3.shape,
            input4.shape,
            input5.shape,
            output0.shape,
        ]
        # Set binding for context base on the input shape
        input_binding_index = self.engine.get_tensor_name(0)
        self.context.set_input_shape(input_binding_index, input0.shape)
        input_binding_index = self.engine.get_tensor_name(1)
        self.context.set_input_shape(input_binding_index, input1.shape)
        input_binding_index = self.engine.get_tensor_name(2)
        self.context.set_input_shape(input_binding_index, input2.shape)
        input_binding_index = self.engine.get_tensor_name(3)
        self.context.set_input_shape(input_binding_index, input3.shape)
        input_binding_index = self.engine.get_tensor_name(4)
        self.context.set_input_shape(input_binding_index, input4.shape)
        input_binding_index = self.engine.get_tensor_name(5)
        self.context.set_input_shape(input_binding_index, input5.shape)
        self.inputs, self.outputs, self.bindings, self.stream = common.allocate_buffers(
            self.engine, model_shapes
        )
        
        # Transfer data to Host memory
        np.copyto(self.inputs[0].host, input0.ravel())
        np.copyto(self.inputs[1].host, input1.ravel())
        np.copyto(self.inputs[2].host, input2.ravel())
        np.copyto(self.inputs[3].host, input3.ravel())
        np.copyto(self.inputs[4].host, input4.ravel())
        np.copyto(self.inputs[5].host, input5.ravel())

        # Do inference
        output = common.do_inference_v3(
            self.context, self.engine, self.bindings, self.inputs, self.outputs, self.stream
        )
        # Post-process the output of the model
        out0 = torch.from_numpy(
            self.outputs[0]
            .host.reshape(output0.shape)
            .copy()
        )
        #out0 = torch.from_numpy(output[0].reshape(output0.shape).copy())
        common.free_buffers(self.inputs, self.outputs, self.stream)
@lix19937

There was a known accuracy bug that was fixed in 10.5. Can you update your TensorRT version?

@jinhonglu

jinhonglu commented Jan 21, 2025

I also face a memory management problem: the results from the ONNX and TRT models match when I copy inputs with htod_async, but the result is different when I use dtod_async.

Here is the code

import pycuda.driver as cuda

if NUMPY:
    # fill host memory with flattened input data
    np.copyto(self.inputs[i].host, model_input.ravel())
    [cuda.memcpy_htod_async(inp.device, inp.host, self.stream) for inp in self.inputs]
elif TORCH:
    # for Torch GPU tensor it's easier, can just do Device to Device copy
    cuda.memcpy_dtod_async(self.inputs[i].device, model_input.data_ptr(), model_input.element_size() * model_input.nelement(), self.stream)
self.context.execute_async_v3(stream_handle=self.stream.handle)

I found that passing a numpy data / torch data will get two different results. I am using TensorRT 10.4

Anyone face the same problem?

@lix19937

cuda.memcpy_dtod_async(self.inputs[i].device, model_input.data_ptr(), model_input.element_size() * model_input.nelement(), self.stream)

to

    cuda.memcpy_dtod_async(inp.device, model_input.data_ptr(), model_input.element_size() * model_input.nelement(), self.stream)

@jinhonglu

jinhonglu commented Jan 22, 2025

cuda.memcpy_dtod_async(self.inputs[i].device, model_input.data_ptr(), model_input.element_size() * model_input.nelement(), self.stream)
to

cuda.memcpy_dtod_async(inp.device, model_input.data_ptr(), model_input.element_size() * model_input.nelement(), self.stream)

Is there any difference?

In the original code, the index i comes from enumerating model_inputs.

More details on the code are shown below.

for i, model_input in enumerate(model_inputs):
    if NUMPY:
    # fill host memory with flattened input data
        np.copyto(self.inputs[i].host, model_input.ravel())
    elif TORCH:
    # for Torch GPU tensor it's easier, can just do Device to Device copy
        cuda.memcpy_dtod_async(self.inputs[i].device, model_input.data_ptr(), model_input.element_size() * model_input.nelement(), self.stream)
if NUMPY:
    [cuda.memcpy_htod_async(inp.device, inp.host, self.stream) for inp in self.inputs]
self.context.execute_async_v3(stream_handle=self.stream.handle)

So I don't see any difference between self.inputs[i].device and inp.device here.

@lix19937

lix19937 commented Jan 23, 2025

First, make sure the data at model_input.data_ptr() actually matches model_input.ravel().

Then, try the synchronous API cuda.memcpy_dtod, and print inp.device in both cases (numpy / torch).

@jinhonglu

The problem is resolved by flattening the tensor before getting the data_ptr.

model_input.data_ptr()

to

torch.flatten(model_input).data_ptr()

I suspect it is because of how the torch tensor is laid out in memory (data_ptr() on a non-contiguous tensor does not point at a densely packed buffer).
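
For reference, a minimal sketch of the device-to-device path that avoids this pitfall (names are illustrative; assumes pycuda with an active CUDA context and a torch CUDA tensor, as in the snippet above):

import pycuda.autoinit  # establishes a CUDA context
import pycuda.driver as cuda
import torch

def copy_torch_to_binding(device_ptr: int, model_input: torch.Tensor, stream: cuda.Stream):
    # .contiguous() guarantees data_ptr() points at a densely packed buffer;
    # torch.flatten() has the same effect because it copies whenever the
    # tensor is not already contiguous.
    flat = model_input.contiguous().view(-1)
    nbytes = flat.element_size() * flat.nelement()
    cuda.memcpy_dtod_async(device_ptr, flat.data_ptr(), nbytes, stream)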

@OctaAIVision
Author

There was a known accuracy bug that was fixed in 10.5. Can you update your TensorRT version?

Thank you for replying. I upgraded TensorRT on my Jetson device to version 10.7, but unfortunately I still get the same wrong results. It seems to be an issue related to dynamic output allocation. I tried padding the output of the first model so that it always has a static shape when fed to the second model, and that approach returned the correct output.
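
For anyone following along, a minimal sketch of the padding workaround described above; the maximum size and pad value below are assumptions for illustration, not values from the actual models:

import numpy as np

MAX_DETECTIONS = 128  # assumed static size expected by the second engine

def pad_to_static(first_model_out: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
    # Pad the dynamic leading dimension up to MAX_DETECTIONS so the second
    # engine always receives the same input shape.
    n = first_model_out.shape[0]
    padded = np.full((MAX_DETECTIONS, *first_model_out.shape[1:]), pad_value,
                     dtype=first_model_out.dtype)
    padded[:n] = first_model_out
    return padded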

@kevinch-nv added the Module:Documentation, triaged, and waiting for feedback labels Feb 10, 2025
@kevinch-nv
Collaborator

@OctaAIVision Trying to understand the full story here:

  • Do engine 1 and engine 2 both have dynamic shapes? What are the profiles that you've set for them? (One way to dump them is sketched after this list.)
  • Does engine 1 always produce the correct results?
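
For reference (not part of the original comment), a small helper that dumps the optimization profiles being asked about, assuming tensorrt is imported as trt as in the snippets above:

def print_profiles(engine: trt.ICudaEngine):
    # Print the (min, opt, max) shape of every input tensor for each
    # optimization profile the engine was built with.
    for p in range(engine.num_optimization_profiles):
        for i in range(engine.num_io_tensors):
            name = engine.get_tensor_name(i)
            if engine.get_tensor_mode(name) != trt.TensorIOMode.INPUT:
                continue
            min_s, opt_s, max_s = engine.get_tensor_profile_shape(name, p)
            print(f"profile {p} | {name}: min={min_s} opt={opt_s} max={max_s}")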

@kevinch-nv self-assigned this Feb 10, 2025