
[blockChooser.cpp::getRegionBlockSize::690] Error Code 2: Internal Error (Assertion memSize >= 0 failed. ) #2045

Closed
AllentDan opened this issue Jun 9, 2022 · 15 comments
Labels: Enhancement (New feature or request), triaged (Issue has been triaged by maintainers)

Comments

@AllentDan

Description

I encountered the following error:

[blockChooser.cpp::getRegionBlockSize::690] Error Code 2: Internal Error (Assertion memSize >= 0 failed. )

Environment

TensorRT Version: 8+
NVIDIA GPU: 1660
NVIDIA Driver Version: 470
CUDA Version: 11.3
CUDNN Version: compatible with cuda 11.3
Operating System: linux x86
Python Version (if applicable): 3.8
PyTorch Version (if applicable): 1.10

Steps To Reproduce

import torch
import onnx
import tensorrt as trt

onnx_model = 'model.onnx'


class NaiveModel(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        B, C, H, W = x.shape

        pad_w = W % 7
        pad_h = H % 7
        x_t = torch.zeros((B, C, H + pad_h, W + pad_w), device=x.device)
        x_t[:, :, :H, :W] = x
        return x_t


device = torch.device('cuda:0')

# generate ONNX model
torch.onnx.export(
    NaiveModel(),
    torch.randn(1, 3, 224, 224),
    onnx_model,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes=dict(
        input=dict({
            0: 'batch',
            2: 'height',
            3: 'width'
        }),
        output=dict({0: 'batch'})),
    opset_version=11)
onnx_model = onnx.load(onnx_model)

# load_tensorrt_plugin()
# create builder and network
logger = trt.Logger(trt.Logger.ERROR)
builder = trt.Builder(logger)
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(EXPLICIT_BATCH)

# parse onnx
parser = trt.OnnxParser(network, logger)

if not parser.parse(onnx_model.SerializeToString()):
    error_msgs = ''
    for error in range(parser.num_errors):
        error_msgs += f'{parser.get_error(error)}\n'
    raise RuntimeError(f'Failed to parse onnx, {error_msgs}')

config = builder.create_builder_config()
config.max_workspace_size = 1 << 30
profile = builder.create_optimization_profile()

profile.set_shape('input', [1, 3, 112, 112], [1, 3, 224, 224],
                  [1, 3, 512, 512])
config.add_optimization_profile(profile)
# create engine
with torch.cuda.device(device):
    engine = builder.build_engine(network, config)

with open('model.engine', mode='wb') as f:
    f.write(bytearray(engine.serialize()))
    print("generating file done!")
@AllentDan
Author

I can avoid the above error by replacing

        x_t = torch.zeros((B, C, H + pad_h, W + pad_w), device=x.device)
        x_t[:, :, :H, :W] = x
        return x_t

with

        x = torch.cat((x, torch.zeros((B, C, pad_h, W), device=x.device)), 2)
        x = torch.cat(
            (x, torch.zeros((B, C, x.shape[-2], pad_w), device=x.device)), -1)
        return x
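A quick sanity check (with made-up sizes and hypothetical helper names `pad_assign`/`pad_cat`) that the cat-based workaround produces the same tensor as the original zeros-and-assign version in PyTorch:

```python
import torch

def pad_assign(x, pad_h, pad_w):
    # original version: allocate zeros, copy the input into the top-left corner
    B, C, H, W = x.shape
    x_t = torch.zeros((B, C, H + pad_h, W + pad_w), device=x.device)
    x_t[:, :, :H, :W] = x
    return x_t

def pad_cat(x, pad_h, pad_w):
    # workaround: concatenate zero blocks along H, then along W
    B, C, H, W = x.shape
    x = torch.cat((x, torch.zeros((B, C, pad_h, W), device=x.device)), 2)
    x = torch.cat((x, torch.zeros((B, C, x.shape[-2], pad_w), device=x.device)), -1)
    return x

x = torch.randn(1, 3, 5, 6)
print(torch.equal(pad_assign(x, 2, 1), pad_cat(x, 2, 1)))  # True
```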

However, if I put another shape transformation line before return x, I still get the above error.

        x = torch.cat((x, torch.zeros((B, C, pad_h, W), device=x.device)), 2)
        x = torch.cat(
            (x, torch.zeros((B, C, x.shape[-2], pad_w), device=x.device)), -1)
        x = x.reshape(-1, x.shape[1], x.shape[2] * x.shape[3])
        return x

@zerollzeng
Collaborator

Your torch script generates a very complicated ONNX model, and what you want to do is just padding.

[screenshot: the exported ONNX graph]

Can you try using ZeroPad2d?

        # x_t = torch.zeros((B, C, H + pad_h, W + pad_w), device=x.device)
        # x_t[:, :, :H, :W] = x
        pad = torch.nn.ZeroPad2d((0, pad_w, 0, pad_h))
        x_t = pad(x)
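For reference (a minimal sketch with made-up sizes), `ZeroPad2d` takes `(left, right, top, bottom)` and reproduces the zeros-and-assign result:

```python
import torch

x = torch.randn(1, 3, 5, 6)
pad_h, pad_w = 2, 1

# zeros-and-assign version from the original model
x_ref = torch.zeros((1, 3, 5 + pad_h, 6 + pad_w))
x_ref[:, :, :5, :6] = x

# ZeroPad2d pads the last two dims: (left, right, top, bottom)
pad = torch.nn.ZeroPad2d((0, pad_w, 0, pad_h))
x_t = pad(x)

print(torch.equal(x_t, x_ref))  # True
```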

@zerollzeng zerollzeng added the triaged Issue has been triaged by maintainers label Jun 9, 2022
@zerollzeng zerollzeng self-assigned this Jun 9, 2022
@zerollzeng
Collaborator

zerollzeng commented Jun 9, 2022

But there is one problem here: TensorRT doesn't support the Mod operator. Does your model have unknown H and W during inference?

@AllentDan
Author

AllentDan commented Jun 9, 2022

Yes, I just want to do dynamic inference. It is actually used in Swin Transformer.

@zerollzeng
Collaborator

Looks like there is no way we can avoid the Mod operator here. How about making pad_w and pad_h network inputs as well, and computing them outside the network?
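One way to read this suggestion (a hedged sketch, not the thread's actual code; `compute_pads` is a hypothetical helper): evaluate the mod formula on the host, pad there as well, and feed the already-padded tensor to the engine, so the exported graph never contains a Mod node.

```python
import torch

def compute_pads(H, W):
    # same formula the model uses, but evaluated on the host,
    # outside the exported graph
    return H % 7, W % 7

x = torch.randn(1, 3, 225, 230)
pad_h, pad_w = compute_pads(x.shape[2], x.shape[3])
# F.pad takes (left, right, top, bottom) for the last two dims
x_padded = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h))
print(x_padded.shape)  # torch.Size([1, 3, 226, 236])
```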

@AllentDan
Author

AllentDan commented Jun 9, 2022

Looks like there is no way we can avoid the Mod operator here. How about making pad_w and pad_h network inputs as well, and computing them outside the network?

Well, it is hard to control the pad_w or pad_h inputs from outside the network. Check out the original PyTorch code here. There are nested blocks.

@AllentDan
Author

AllentDan commented Jun 9, 2022

So, if the Mod operator raises the error, maybe writing a plugin is the workaround? I just want to confirm that the Mod operator is the cause. And if I include the Mod part inside the plugin, there should be no shape-related issues, right?

@zerollzeng
Collaborator

Yes, a plugin should work.

@AllentDan
Author

Hi, @zerollzeng. I tried another way to avoid the mod operator here:

class NaiveModel(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        B, C, H, W = x.shape

        pad_h = H - (H // 7) * 7
        x = torch.cat((x, torch.zeros((B, C, pad_h, W), device=x.device)), 2)
        x = x.reshape(-1, x.shape[1], x.shape[2] * x.shape[3])
        return x

But again the above error was triggered.

@zerollzeng
Collaborator

How about

class NaiveModel(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        B, C, H, W = x.shape

        pad_h = H - (H // 7) * 7
        pad_w = W - (W // 7) * 7
        pad = torch.nn.ZeroPad2d((0, pad_w, 0, pad_h))
        x_t = pad(x)
        return x_t

will it work?

@AllentDan
Author

How about

class NaiveModel(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        B, C, H, W = x.shape

        pad_h = H - (H // 7) * 7
        pad_w = W - (W // 7) * 7
        pad = torch.nn.ZeroPad2d((0, pad_w, 0, pad_h))
        x_t = pad(x)
        return x_t

will it work?

It failed. Besides, it is not the padding that fails, as I mentioned above: the padding alone can be converted to TRT, but it fails once we append another reshape.

@zerollzeng
Collaborator

Yes. For the model

class NaiveModel(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        B, C, H, W = x.shape

        pad_h = H - (H // 7) * 7
        pad_w = W - (W // 7) * 7
        pad = torch.nn.ZeroPad2d((0, pad_w, 0, pad_h))
        x = pad(x)
        x = x.reshape(B, C, (H + pad_h) * (W + pad_w))
        return x

I can see the error:

[06/09/2022-15:34:34] [E] [TRT] ModelImporter.cpp:775: input: "input"
input: "onnx::Pad_59"
input: "onnx::Pad_60"
output: "onnx::Reshape_61"
name: "Pad_51"
op_type: "Pad"
attribute {
  name: "mode"
  s: "constant"
  type: STRING
}

[06/09/2022-15:34:34] [E] [TRT] ModelImporter.cpp:776: --- End node ---
[06/09/2022-15:34:34] [E] [TRT] ModelImporter.cpp:779: ERROR: ModelImporter.cpp:180 In function parseGraph:
[6] Invalid Node - Pad_51
[shuffleNode.cpp::symbolicExecute::392] Error Code 4: Internal Error (Reshape_40: IShuffleLayer applied to shape tensor must have 0 or 1 reshape dimensions: dimensions were [-1,2])
[06/09/2022-15:34:34] [E] Failed to parse onnx file
[06/09/2022-15:34:34] [I] Finish parsing network model
[06/09/2022-15:34:34] [E] Parsing model failed
[06/09/2022-15:34:34] [E] Failed to create engine from model or file.
[06/09/2022-15:34:34] [E] Engine set up failed
&&&& FAILED TensorRT.trtexec [TensorRT v8401] # /usr/src/tensorrt/bin/trtexec --onnx=model.onnx --optShapes=input:1x3x224x224 --shapes=input:1x3x224x224

@jackwish @nvpohanh Do we support ND shape tensor in 8.4?

@grimoire

A workaround is to replace the reshape in prepare_onnx_paddings with a Concat of two strided Slices (step=2) of the paddings, extracting the begins and the ends.

    # paddings = sym_help._reshape_helper(
    #     g, paddings, g.op("Constant", value_t=torch.tensor([-1, 2])))
    # paddings = g.op(
    #     "Transpose",
    #     torch.onnx.symbolic_opset10.flip(g, paddings, [0]),
    #     perm_i=[1, 0])
    # paddings = sym_help._reshape_helper(
    #     g, paddings, g.op("Constant", value_t=torch.tensor([-1])))


    paddings = torch.onnx.symbolic_opset10.flip(g, paddings, [0])
    begins = sym_help._slice_helper(
        g, paddings, axes=[0], starts=[1], ends=[0xffff], steps=[2])
    ends = sym_help._slice_helper(
        g, paddings, axes=[0], starts=[0], ends=[0xffff], steps=[2])
    paddings = g.op('Concat', begins, ends, axis_i=0)
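The equivalence of the two routes can be checked numerically (a sketch; the flat padding vector follows PyTorch's last-dim-first convention, and the values are made up):

```python
import torch

# torch-style flat paddings for a rank-4 tensor, last dim first:
# [w_begin, w_end, h_begin, h_end, c_begin, c_end, b_begin, b_end]
p = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])

# original route: reshape to [-1, 2], flip rows, transpose, flatten
ref = p.reshape(-1, 2).flip(0).t().reshape(-1)

# workaround route: flip the flat vector, then two strided slices
flipped = p.flip(0)      # [8, 7, 6, 5, 4, 3, 2, 1]
begins = flipped[1::2]   # [7, 5, 3, 1]
ends = flipped[0::2]     # [8, 6, 4, 2]
alt = torch.cat([begins, ends])

print(ref.tolist())  # [7, 5, 3, 1, 8, 6, 4, 2]
print(torch.equal(ref, alt))  # True
```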

@nvpohanh
Collaborator

Do we support ND shape tensor in 8.4?

ND shape tensor will be supported in the next version after TRT 8.4.

@ttyio
Collaborator

ttyio commented Nov 1, 2022

Closing since there has been no activity for more than 3 weeks. Please reopen if you still have questions, thanks!

@ttyio ttyio closed this as completed Nov 1, 2022