From ceef616f2caa9b3189d472f3713e774fdb594db4 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Sat, 31 Oct 2020 06:21:48 -0700 Subject: [PATCH] Extract channels from weight shape for conv2d. (#6805) --- src/runtime/contrib/tensorrt/tensorrt_ops.cc | 2 +- tests/python/contrib/test_tensorrt.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc index 4c5eeea1e644..a86f107941bc 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc @@ -242,7 +242,7 @@ class Conv2DOpConverter : public TensorRTOpConverter { auto str_dilation = params->node.GetAttr>("dilation"); auto str_padding = params->node.GetAttr>("padding"); int groups = std::stoi(params->node.GetAttr>("groups")[0]); - int channels = std::stoi(params->node.GetAttr>("channels")[0]); + int channels = weight_shape[0]; // TRT conv2d op doesn't support asymmetric padding before 5.1, so we // workaround by adding a padding layer before the pooling op. nvinfer1::DimsHW prepadding, postpadding; diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index 6f615397db58..9faf51f397f3 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -251,7 +251,6 @@ def get_graph( out = relay.nn.conv2d( x, kernel, - channels=k_shape[0], kernel_size=k_shape[2:4], groups=groups, padding=padding,