Skip to content

Commit

Permalink
Conv node bug, cached state was incoherent (microsoft#10041)
Browse files Browse the repository at this point in the history
* Move the init earlier to keep the cache coherent
* Move the setting of w_desc later, and move the zero-shape check later, so that all cacheable changes are caught.
* Add comment
  • Loading branch information
RyanUnderhill authored Mar 1, 2022
1 parent f4b2d3a commit c1cf16e
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions onnxruntime/core/providers/cuda/nn/conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,6 @@ Status Conv<T>::UpdateState(OpKernelContext* context, bool bias_expected) const
s_.slice_axes = slice_axes;

s_.Y = context->Output(0, TensorShape(s_.y_dims));
if (s_.Y->Shape().Size() == 0) {
return Status::OK();
}
if (post_slicing_required) {
// Post slicing needed. Create and fill in the Conv results in an intermediate buffer.
s_.memory_for_cudnn_conv_results = GetScratchBuffer<void>(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size);
Expand Down Expand Up @@ -206,9 +203,14 @@ Status Conv<T>::UpdateState(OpKernelContext* context, bool bias_expected) const
dilations.push_back(1);
}

if (w_dims_changed) {
if (w_dims_changed)
ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType<CudaT>()));

// We must delay returning early until here so that the weight dims have been cached properly
if (s_.Y->Shape().Size() == 0) {
return Status::OK();
}

ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
Expand Down

0 comments on commit c1cf16e

Please sign in to comment.