Deterministic cudnn algorithms #2893

Merged: 1 commit, Oct 10, 2017
32 changes: 20 additions & 12 deletions test/test_nn.py
@@ -35,16 +35,6 @@
from scipy import stats


@contextlib.contextmanager
def use_cudnn(should_use):
    orig = torch.backends.cudnn.enabled
    torch.backends.cudnn.enabled = should_use
    try:
        yield
    finally:
        torch.backends.cudnn.enabled = orig


def default_tensor_type(type):
    type_str = torch.typename(type)

@@ -1786,6 +1776,25 @@ def test_Conv2d_inconsistent_types_on_GPU_with_cudnn(self):
        # but it should work with the same type
        nn.functional.conv2d(inputs.float(), weights.float(), bias.float())

    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
    def test_Conv2d_deterministic_cudnn(self):
        dtype = torch.cuda.FloatTensor
        inputs = Variable(torch.randn(2, 3, 5, 5).type(dtype), requires_grad=True)
        with cudnn.flags(enabled=True, benchmark=True, deterministic=True):
            conv1 = torch.nn.Conv2d(3, 3, 3).type(dtype)
            conv2 = torch.nn.Conv2d(3, 3, 3).type(dtype)
            conv2.bias.data.copy_(conv1.bias.data)
            conv2.weight.data.copy_(conv1.weight.data)
            out1 = conv1(inputs)
            out2 = conv2(inputs)
            self.assertEqual(out1, out2, prec=0.0)
            y = torch.randn(out1.size()).type(dtype)
            out1.backward(y)
            out2.backward(y)
            self.assertEqual(conv1.bias.grad.data, conv2.bias.grad.data, prec=0.0)
            self.assertEqual(conv1.weight.grad.data, conv2.weight.grad.data, prec=0.0)

    def test_Conv2d_missing_argument(self):
        c = nn.Conv2d(3, 3, 3)
        self.assertRaises(RuntimeError, lambda: c(None))
@@ -2973,7 +2982,7 @@ def func(*inputs):
    lx, lweight = inputs
    lbias = None
    # We disable cudnn during forward to avoid finite difference imprecision issues
    with use_cudnn(False):
    with cudnn.flags(enabled=False):
        out = F.conv2d(lx, lweight, lbias, stride, padding, dilation, groups)
    return out

@@ -2997,7 +3006,6 @@ def test_conv_double_backward(self):
    for stride, padding, chan_in, chan_out, dilation in \
            product([1, 2], [0, 1, 2], [2], [3], dilations):
        no_weight = False

        result = self.run_conv_double_back_test(kern, stride,
                                                padding, chan_in, chan_out,
                                                batch_size, inp_size, dilation,
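A standalone sketch of the determinism check exercised by the new test above, outside the test harness (assumes a CUDA build with cuDNN available; illustrative only, not part of this diff):

import torch
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

# Two convolution layers with identical weights should produce bitwise-identical
# outputs when cuDNN is restricted to deterministic algorithms.
x = Variable(torch.randn(2, 3, 5, 5).cuda(), requires_grad=True)
conv_a = torch.nn.Conv2d(3, 3, 3).cuda()
conv_b = torch.nn.Conv2d(3, 3, 3).cuda()
conv_b.load_state_dict(conv_a.state_dict())  # copy weight and bias
with cudnn.flags(enabled=True, benchmark=True, deterministic=True):
    out_a = conv_a(x)
    out_b = conv_b(x)
assert (out_a.data - out_b.data).abs().max() == 0  # exact match expected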
19 changes: 19 additions & 0 deletions torch/backends/cudnn/__init__.py
@@ -2,6 +2,7 @@
import sys
import torch
import warnings
from contextlib import contextmanager

enabled = True # set to False to globally disable cuDNN

@@ -58,6 +59,7 @@ def is_acceptable(tensor):

_handles = {}

deterministic = False

benchmark = False
verbose = False

@@ -81,6 +83,23 @@ def is_acceptable(tensor):
CUDNN_RNN_ALGO_PERSIST_DYNAMIC = 2


def set_flags(_enabled, _benchmark, _deterministic, _verbose):
    global enabled, benchmark, deterministic, verbose
    orig_flags = enabled, benchmark, deterministic, verbose
    enabled, benchmark, deterministic, verbose = _enabled, _benchmark, _deterministic, _verbose
    return orig_flags


@contextmanager
def flags(enabled=False, benchmark=False, deterministic=False, verbose=False):
    orig_flags = set_flags(enabled, benchmark, deterministic, verbose)
    try:
        yield
    finally:
        # recover the previous values
        set_flags(orig_flags[0], orig_flags[1], orig_flags[2], orig_flags[3])


class CuDNNHandle:
    def __init__(self):
        ptr = ctypes.c_void_p()
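A usage sketch for the flags() context manager and module-level attributes added above (the model and input here are hypothetical placeholders):

import torch
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

model = torch.nn.Conv2d(3, 16, 3).cuda()             # hypothetical model
inputs = Variable(torch.randn(4, 3, 32, 32).cuda())  # hypothetical input

# Scoped: deterministic algorithms only inside the block; the previous flag
# values are restored on exit, even if an exception is raised.
with cudnn.flags(enabled=True, benchmark=False, deterministic=True):
    out = model(inputs)

# Global: the same flags can also be set directly on the module.
cudnn.deterministic = True
cudnn.benchmark = False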
14 changes: 9 additions & 5 deletions torch/csrc/autograd/functions/convolution.cpp
@@ -97,6 +97,10 @@ auto ConvParams::use_cudnn(const at::Tensor& input) const -> bool {
  if (!input.type().isCuda() || !cudnn_enabled) {
    return false;
  }
  if (deterministic && is_dilated()) {
    // cudnn doesn't support deterministic dilated convolution fully yet
    return false;
  }
  if (is_dilated()) {
    cudaDeviceProp* prop = THCState_getCurrentDeviceProperties(state);
    // NOTE: extra parenthesis around numbers disable clang warnings about dead code
@@ -251,13 +255,13 @@ auto ConvForward::apply(const variable_list& inputs) -> variable_list {
state, torch::cudnn::getCudnnHandle(), torch::cudnn::getCudnnDataType(input),
(THVoidTensor*)input.unsafeGetTH(false), (THVoidTensor*)weight.unsafeGetTH(false),
bias.defined() ? (THVoidTensor*)bias.unsafeGetTH(false) : nullptr, (THVoidTensor*)output.unsafeGetTH(false),
padding, stride, dilation, groups, benchmark));
padding, stride, dilation, groups, benchmark, deterministic));
} else {
convolution.reset(cudnn_convolution_full_forward(
state, torch::cudnn::getCudnnHandle(), torch::cudnn::getCudnnDataType(input),
(THVoidTensor*)input.unsafeGetTH(false), (THVoidTensor*)weight.unsafeGetTH(false),
bias.defined() ? (THVoidTensor*)bias.unsafeGetTH(false) : nullptr, (THVoidTensor*)output.unsafeGetTH(false),
padding, stride, dilation, groups, benchmark));
padding, stride, dilation, groups, benchmark, deterministic));
}
#endif
} else {
@@ -345,12 +349,12 @@ auto ConvBackward::apply(const variable_list& grad_outputs) -> variable_list {
cudnn_convolution_forward(
state, torch::cudnn::getCudnnHandle(), torch::cudnn::getCudnnDataType(input),
(THVoidTensor*)grad_output.unsafeGetTH(false), (THVoidTensor*)weight.unsafeGetTH(false), (THVoidTensor*)grad_input.unsafeGetTH(false),
convolution.get(), benchmark);
convolution.get(), benchmark, deterministic);
} else {
cudnn_convolution_backward_data(
state, torch::cudnn::getCudnnHandle(), torch::cudnn::getCudnnDataType(input),
(THVoidTensor*)grad_output.unsafeGetTH(false), (THVoidTensor*)grad_input.unsafeGetTH(false), (THVoidTensor*)weight.unsafeGetTH(false),
convolution.get(), benchmark);
convolution.get(), benchmark, deterministic);
}
#endif
} else if (groups == 1) {
@@ -379,7 +383,7 @@ auto ConvBackward::apply(const variable_list& grad_outputs) -> variable_list {
cudnn_convolution_backward_filter(
state, torch::cudnn::getCudnnHandle(), torch::cudnn::getCudnnDataType(input),
(THVoidTensor*)grad_output.unsafeGetTH(false), (THVoidTensor*)input.unsafeGetTH(false), (THVoidTensor*)grad_weight.unsafeGetTH(false),
convolution.get(), benchmark);
convolution.get(), benchmark, deterministic);

if (bias.defined() && should_compute_output(2)) {
grad_bias = bias.type().tensor();
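A behavioral sketch of the new deterministic-plus-dilated guard in use_cudnn() above: a dilated convolution still runs with deterministic set, but it is expected to take the non-cuDNN code path (this is an inference from the guard, not shown in the diff):

import torch
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

dilated = torch.nn.Conv2d(3, 3, 3, dilation=2).cuda()
x = Variable(torch.randn(1, 3, 9, 9).cuda())
with cudnn.flags(enabled=True, deterministic=True):
    # use_cudnn() returns false for this combination, so the non-cuDNN
    # fallback kernels handle the convolution.
    out = dilated(x)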
1 change: 1 addition & 0 deletions torch/csrc/autograd/functions/convolution.h
@@ -29,6 +29,7 @@ struct ConvParams {
std::vector<int> output_padding;
int groups;
bool benchmark;
bool deterministic;

bool cudnn_enabled;

bool is_strided() const;
3 changes: 2 additions & 1 deletion torch/csrc/autograd/functions/init.cpp
@@ -35,14 +35,15 @@ struct ConvCtor {
ConvForward* operator()(PyObject* args) {
ConvParams params;

TupleParser parser(args, 8);
TupleParser parser(args, 9);
parser.parse(params.stride, "stride");
parser.parse(params.padding, "padding");
parser.parse(params.dilation, "dilation");
parser.parse(params.transposed, "transposed");
parser.parse(params.output_padding, "output_padding");
parser.parse(params.groups, "groups");
parser.parse(params.benchmark, "benchmark");
parser.parse(params.deterministic, "deterministic");
parser.parse(params.cudnn_enabled, "cudnn_enabled");

return new ConvForward(std::move(params));
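For context, a sketch of how a Python caller could supply the nine arguments parsed above, in the same order (ConvNd and _pair are assumptions about the surrounding code base, not shown in this diff):

import torch
import torch.backends.cudnn as cudnn
from torch.nn.modules.utils import _pair

# Argument order mirrors the parser: stride, padding, dilation, transposed,
# output_padding, groups, benchmark, deterministic, cudnn_enabled.
ConvNd = torch._C._functions.ConvNd  # assumed binding backed by ConvCtor
f = ConvNd(_pair(1), _pair(0), _pair(1), False, _pair(0), 1,
           cudnn.benchmark, cudnn.deterministic, cudnn.enabled)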