Skip to content

Commit

Permalink
Add compile guard to Dropout's FIsCUDAGraphsCompatible def
Browse files Browse the repository at this point in the history
  • Loading branch information
DickJC123 committed Feb 18, 2022
1 parent eaa7fc7 commit c44cfc6
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/operator/nn/dropout.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,18 @@ NNVM_REGISTER_OP(Dropout)
// Dropout is a passthrough during inference for all impls
if (!is_train)
return true;
#if MXNET_USE_CUDNN_DROPOUT
// cuDNN impl is compatible during training as well
const DropoutParam& param =
nnvm::get<DropoutParam>(attrs.parsed);
real_t pkeep = 1.0f - param.p;
bool cudnn_off =
param.cudnn_off && param.cudnn_off.value();
bool cudnn_available = pkeep > 0 && !cudnn_off;
return MXNET_USE_CUDNN_DROPOUT && cudnn_available;
return cudnn_available;
#else
return false;
#endif // MXNET_USE_CUDNN_DROPOUT
})
.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", DropoutCompute<gpu>);

Expand Down

0 comments on commit c44cfc6

Please sign in to comment.