import torch
from torch import nn
from torch.autograd import Variable
from torch.autograd.function import InplaceFunction


class ZoneoutFunction(InplaceFunction):
    # Autograd function backing the Zoneout module below; renamed from
    # Zoneout so that the nn.Module of the same name does not shadow it.

    @staticmethod
    def _make_mask(input):
        # An uninitialized tensor of the same type and shape as the input.
        return input.new().resize_as_(input)

    @classmethod
    def forward(cls, ctx, current_input, previous_input, p=None, mask=None,
                train=False, inplace=False):
        if p is None and mask is None:
            raise ValueError('Either p or mask must be provided')
        if p is not None and mask is not None:
            raise ValueError('Only one of p and mask can be provided')
        if p is not None and (p < 0 or p > 1):
            raise ValueError('zoneout probability has to be between 0 and 1, '
                             'but got {}'.format(p))
        if mask is not None and \
                not isinstance(mask, torch.ByteTensor) and \
                not isinstance(mask, torch.cuda.ByteTensor):
            raise ValueError('mask must be a ByteTensor')
        if current_input.size() != previous_input.size():
            raise ValueError(
                'Current and previous inputs must be of the same '
                'size, but current has size {current} and '
                'previous has size {previous}.'.format(
                    current='x'.join(str(size) for size in current_input.size()),
                    previous='x'.join(str(size) for size in previous_input.size()))
            )
        if type(current_input) != type(previous_input):
            raise ValueError('Current and previous inputs must be of the same '
                             'type, but current is {current} and previous is '
                             '{previous}'.format(current=type(current_input),
                                                 previous=type(previous_input)))

        ctx.p = p
        ctx.train = train
        ctx.inplace = inplace

        if ctx.inplace:
            ctx.mark_dirty(current_input)
            output = current_input
        else:
            output = current_input.clone()

        # At evaluation time zoneout is the identity on the current input.
        if not ctx.train:
            return output

        ctx.current_mask = cls._make_mask(current_input)
        ctx.previous_mask = cls._make_mask(current_input)
        if mask is None:
            # Keep each element of the current input with probability 1 - p.
            ctx.current_mask.bernoulli_(1 - ctx.p)
        else:
            # The docstring below specifies that set elements of the mask
            # select the values to be swapped, so drop those from the
            # current input.
            ctx.current_mask.fill_(1).masked_fill_(mask, 0)
        # The previous input contributes exactly where the current one does not.
        ctx.previous_mask.fill_(1).sub_(ctx.current_mask)
        output.mul_(ctx.current_mask).add_(previous_input.mul(ctx.previous_mask))
        return output
    @staticmethod
    def backward(ctx, grad_output):
        if ctx.train:
            # Route the gradient to whichever input each element was taken from.
            return grad_output * Variable(ctx.current_mask), \
                grad_output * Variable(ctx.previous_mask), \
                None, None, None, None
        else:
            return grad_output, None, None, None, None, None


class Zoneout(nn.Module):
    r"""During training of an RNN, randomly swaps some of the elements of the
    input tensor with their values from a previous time-step, with probability
    *p*, using samples from a Bernoulli distribution. The elements to be
    swapped are re-randomized on every time-step by default, but a shared
    mask can be provided.

    Zoneout is a variant of dropout designed specifically for regularizing
    recurrent connections of LSTMs or GRUs. While dropout applies a zero mask
    to its inputs, zoneout applies an identity mask when incrementing a
    time-step.
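
    Concretely, with a per-element mask :math:`m \sim \mathrm{Bernoulli}(1 - p)`
    sampled at every time-step, the output is
    :math:`m \odot h_t + (1 - m) \odot h_{t-1}`, where :math:`h_t` is the
    current input and :math:`h_{t-1}` the previous one.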

    It has proven to be an effective technique for regularizing LSTMs and
    GRUs because, contrary to dropout, gradient and state information are
    more readily propagated through time. For further details, consult the
    paper `Zoneout: Regularizing RNNs by Randomly Preserving Hidden
    Activations`_ .

    As with dropout, during evaluation the module simply computes an
    identity function.

    Args:
        p: probability of an element being swapped with its value from the
            previous time-step. Default: ``None``
        inplace: If set to ``True``, will do this operation in-place.
            Default: ``False``
        mask: `ByteTensor`. A mask whose set elements select the values to be
            swapped. The intended use case for this argument is sharing a
            zoneout mask across several time-steps.

    Shape:
        - Input: `Any`. A pair of tensors of the same shape.
        - Output: `Same`. Output is of the same shape as the input.

    Examples::

        >>> zoneout = nn.Zoneout(p=0.15)
        >>> current_hidden_state = Variable(torch.Tensor([1, 2, 3]))
        >>> previous_hidden_state = Variable(torch.Tensor([4, 5, 6]))
        >>> output = zoneout(current_hidden_state, previous_hidden_state)

        Using a shared mask:

        >>> mask = torch.ByteTensor(3).bernoulli_(0.5)
        >>> zoneout = nn.Zoneout(mask=mask)
        >>> current_hidden_state = Variable(torch.Tensor([1, 2, 3]))
        >>> previous_hidden_state = Variable(torch.Tensor([4, 5, 6]))
        >>> output = zoneout(current_hidden_state, previous_hidden_state)

        Wrapping around a `GRUCell`:

        >>> rnn = nn.GRUCell(10, 20)
        >>> zoneout = nn.Zoneout(p=0.15)
        >>> input = Variable(torch.randn(6, 3, 10))
        >>> h = Variable(torch.randn(3, 20))
        >>> output = []
        >>> for i in range(6):
        ...     h_new = rnn(input[i], h)
        ...     h = zoneout(h_new, h)
        ...     output.append(h)

    .. _Zoneout\: Regularizing RNNs by Randomly Preserving Hidden Activations:
        https://arxiv.org/abs/1606.01305
    """

    def __init__(self, p=None, inplace=False, mask=None):
        super(Zoneout, self).__init__()
        if p is None and mask is None:
            raise ValueError('Either p or mask must be provided')
        if p is not None and mask is not None:
            raise ValueError('Only one of p and mask can be provided')
        if p is not None and (p < 0 or p > 1):
            raise ValueError('zoneout probability has to be between 0 and 1, '
                             'but got {}'.format(p))
        if mask is not None and \
                not isinstance(mask, torch.ByteTensor) and \
                not isinstance(mask, torch.cuda.ByteTensor):
            raise ValueError('mask must be a ByteTensor')
        self.p = p
        self.inplace = inplace
        self.mask = mask

    def forward(self, current_input, previous_input):
        # Arguments follow the (current, previous) order used by the
        # autograd function and the examples above.
        return zoneout(current_input, previous_input, self.p, self.mask,
                       self.training, self.inplace)

    def __repr__(self):
        inplace_str = ', inplace' if self.inplace else ''
        if self.mask is not None:
            mask_str = 'mask=ByteTensor of size ' + \
                'x'.join(str(size) for size in self.mask.size())
        else:
            mask_str = 'p=' + str(self.p)
        return self.__class__.__name__ + '(' + mask_str + inplace_str + ')'


def zoneout(current_input, previous_input, p=None, mask=None, training=False,
            inplace=False):
    # Functional interface; dispatches to the autograd function with the
    # same (current, previous) argument order as the module.
    return ZoneoutFunction.apply(current_input, previous_input, p,
                                 mask, training, inplace)
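

# A minimal smoke-test sketch, not part of the original module: it runs
# zoneout with an assumed p=0.5 on a vector of ones against a previous state
# of zeros, so the training-mode output is a visible mix of kept and swapped
# values.
if __name__ == '__main__':
    z = Zoneout(p=0.5)
    current = Variable(torch.ones(4))
    previous = Variable(torch.zeros(4))

    z.train()
    # Each element is either kept (1) or swapped with the previous state (0).
    print(z(current, previous))

    z.eval()
    # In evaluation mode zoneout is the identity on the current input.
    print(z(current, previous))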