
Dynamic batch-size ONNX model strangely results in fixed batch-size output model with trtexec #996

Closed
handoku opened this issue Jan 6, 2021 · 7 comments
Labels
triaged Issue has been triaged by maintainers

Comments


handoku commented Jan 6, 2021

Description

I have exported an ONNX model from PyTorch; code snippet:

import numpy as np
import torch

# Fixed input for a reproducible export and sanity check.
np.random.seed(0)
x = np.random.randn(1, 3, 256, 256).astype(np.float32)
x = torch.from_numpy(x).cuda()

torch_out = model(x)
print(torch_out)

# Mark the batch dimension of every input/output as dynamic.
dynamic_axes = {'INPUT__0': {0: 'batch_size'}, 'OUTPUT__0': {0: 'batch_size'}, 'OUTPUT__1': {0: 'batch_size'}}
torch.onnx.export(model, x, "cdcn.onnx", opset_version=11,
    export_params=True, keep_initializers_as_inputs=True, do_constant_folding=True,
    input_names=['INPUT__0'],
    output_names=['OUTPUT__0', 'OUTPUT__1'],
    dynamic_axes=dynamic_axes
    )
print("onnx saved!")

The ONNX file looks OK in Netron, so I then produced a TRT engine with trtexec.

However, when loading the plan file with Triton Inference Server, it outputs:

I0106 09:18:18.800655 48551 autofill.cc:213] TensorRT autofill: OK: 
W0106 09:18:18.800690 48551 autofill.cc:165] The TRT engine doesn't specify appropriate dimensions to support dynamic batching
I0106 09:18:18.800721 48551 model_config_utils.cc:276] autofilled config: name: "cdcn_trt"
platform: "tensorrt_plan"
input {
  name: "INPUT__0"
  data_type: TYPE_FP32
  dims: -1
  dims: 3
  dims: 256
  dims: 256
}
output {
  name: "OUTPUT__0"
  data_type: TYPE_FP32
  dims: 5
  dims: 2
}
output {
  name: "OUTPUT__1"
  data_type: TYPE_FP32
  dims: 5
  dims: 2
}
instance_group {
  count: 2
  gpus: 0
  kind: KIND_GPU
}
default_model_filename: "model.plan"

I have no idea where the dims: 5 came from.

Could you help me out with this?

Environment

TensorRT Version: 7.0.0
GPU Type: Tesla-T4
Nvidia Driver Version: 418.xx
CUDA Version: 10.1
Operating System + Version: Ubuntu 18.04

Relevant Files

onnx file : https://drive.google.com/file/d/1hzFuOPQtN0ivOsP6IOi6h4gw3-0NPH0d/view?usp=sharing

Steps To Reproduce

./bin/trtexec --onnx=./cdcn.onnx --explicitBatch --minShapes='INPUT__0':1x3x256x256 --optShapes='INPUT__0':4x3x256x256 --maxShapes='INPUT__0':8x3x256x256 --buildOnly --saveEngine=./cdcn.plan --workspace=11288 --device=0

Then load the model with Triton Inference Server.
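For reference, Triton expects the plan file in a versioned model repository; the autofilled config in the log above is produced when strict model configuration is disabled (--strict-model-config=false). A sketch of the assumed layout, not shown in the original report:

model_repository/
└── cdcn_trt/
    └── 1/
        └── model.plan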


handoku commented Jan 7, 2021

I found that after PyTorch's interpolate with bilinear mode and align_corners=True, the resulting TRT engine becomes a fixed-batchsize model, even though when I checked the onnx-tensorrt parser, isDynamic(layer->getOutput(0)->getDimensions()) returns true for the Resize layer.

This makes it impossible to create a TRT plan file that supports dynamic batching.

A very simple example could be :

import numpy as np
import torch
import torch.nn.functional as F

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

    def forward(self, x):
        # Bilinear resize with align_corners=True triggers the issue.
        out = F.interpolate(x, size=(32, 32), mode='bilinear', align_corners=True)
        return out

torch_input = torch.from_numpy(np.random.randn(1, 3, 64, 64).astype(np.float32)).cuda()
model = MyModel()

onnx_model_file = "resize.onnx"
torch.onnx.export(model, torch_input, onnx_model_file, verbose=False,
        export_params=True,
        input_names=['input'], output_names=['out'], opset_version=11, keep_initializers_as_inputs=True,
        dynamic_axes={"input": {0: "bs"}, "out": {0: "bs"}})
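A sketch (not part of the original report) that helps isolate the problem to TensorRT: running the exported graph at two batch sizes with onnxruntime shows the ONNX model itself is genuinely dynamic. Assumes onnxruntime is installed and the model was saved as resize.onnx:

import numpy as np
import onnxruntime as ort

# If both batch sizes run and return (bs, 3, 32, 32), the dynamic axis
# survived the export and the fixed shape comes from the TRT build.
sess = ort.InferenceSession("resize.onnx")
for bs in (1, 4):
    x = np.random.randn(bs, 3, 64, 64).astype(np.float32)
    (out,) = sess.run(None, {"input": x})
    print(bs, "->", out.shape)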

Then create a plan file with trtexec:

./bin/trtexec --onnx=./resize.onnx --explicitBatch --minShapes='input':1x3x64x64 --optShapes='input':4x3x64x64 --maxShapes='input':8x3x64x64 --buildOnly --saveEngine=./resize.plan --workspace=11288

When loading the TRT model with trtserver, it outputs:

I0107 13:40:31.901351 48889 autofill.cc:213] TensorRT autofill: OK: 
W0107 13:40:31.901367 48889 autofill.cc:165] The TRT engine doesn't specify appropriate dimensions to support dynamic batching
I0107 13:40:31.901388 48889 model_config_utils.cc:276] autofilled config: name: "cdcn_trt"
platform: "tensorrt_plan"
input {
  name: "input"
  data_type: TYPE_FP32
  dims: -1
  dims: 3
  dims: 64
  dims: 64
}
output {
  name: "out"
  data_type: TYPE_FP32
  dims: 1
  dims: 3
  dims: 32
  dims: 32
}
instance_group {
  count: 2
  gpus: 0
  kind: KIND_GPU
}
default_model_filename: "model.plan"


handoku commented Jan 8, 2021

Probably it's a bug in TensorRT's Resize op or in the builder module.

I have tested on TensorRT 7.2.1.4 with the Python SDK; code snippet:

import tensorrt as trt

# Deserialize the engine built above and inspect its binding shapes.
my_trt_logger = trt.Logger(trt.Logger.WARNING)
trt_runtime = trt.Runtime(my_trt_logger)
with open('resize.plan', 'rb') as f:
    trt_engine = trt_runtime.deserialize_cuda_engine(f.read())

print('input shape : ', trt_engine.get_binding_shape(0))
print('out shape : ', trt_engine.get_binding_shape(1))

output:

input shape :  (-1, 3, 64, 64)
out shape :  (1, 3, 32, 32)
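A follow-up sketch (not in the thread) that continues from the trt_engine object above: setting a concrete input shape on an execution context shows what output shape the engine actually resolves at runtime. On a healthy dynamic engine this should print (4, 3, 32, 32):

# Resolve the output shape for a batch of 4; a fixed-batch engine will
# reject the shape or keep reporting a batch of 1.
context = trt_engine.create_execution_context()
context.set_binding_shape(0, (4, 3, 64, 64))
print('resolved out shape : ', context.get_binding_shape(1))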


ttyio commented Feb 23, 2021

Hello @handoku, thanks for reporting.
The shape reported in your output is wrong; could you try release 7.2? We have several dynamic-shape-related fixes in that build. Thanks!

@ttyio ttyio added Release: 7.x triaged Issue has been triaged by maintainers labels Feb 23, 2021

handoku commented Feb 28, 2021

@ttyio
I have already tested with release 7.2; it doesn't work.
BTW, according to the TensorRT docs, the cause of this problem is that the Resize op in TensorRT does not support broadcasting across the batch dimension, so it can't support a dynamic batch size.


ttyio commented Mar 1, 2021

Hello @handoku,

Broadcast and dynamic shapes are orthogonal features, though sometimes they break each other. BTW, where did you see the broadcast limitation on the Resize op?

I have just tried 7.2.1.6.

I generated the ONNX model using your instructions and built the TRT engine with this command line:

      ./trtexec --onnx=./output.onnx --explicitBatch --minShapes='input':1x3x64x64 --optShapes='input':2x3x64x64 --maxShapes='input':8x3x64x64  --buildOnly --saveEngine=./output.engine

Then, using the script you provided, I see this output:

       input shape :  (-1, 3, 64, 64)
       out shape :  (-1, 3, 32, 32)

Could you give it a try? Thanks!


handoku commented Mar 1, 2021

@ttyio
I was testing inside the NGC TensorRT 20.10 docker image, where the TensorRT version is 7.2.1.4. I am quite sure my output is what I posted, while yours seems correct now.

Anyway, thanks for your test; I will give 7.2.1.6 another try.

Docs link: support matrix

[screenshot of the TensorRT support matrix table]


ttyio commented May 26, 2021

Closing since there has been no activity for more than 3 weeks; please reopen if you still have questions, thanks!

@ttyio ttyio closed this as completed May 26, 2021