linear_scheduler.py
import warnings

import torch.nn as nn
import torch.optim as optim


class linear_scheduler_warmup(nn.Module):
    r"""Scheduler that handles the progression of the learning rate. During the warm-up
    phase the learning rate increases linearly from a start value to a maximum, and
    afterwards decreases linearly to a set end value.

    Args:
        optimizer: optimizer whose learning rate is fetched and adjusted
        start: peak learning rate; after warm-up, the learning rate decreases from this value
        end: final learning rate
        warmup_start: initial learning rate
        total_steps: number of training iterations and hence scheduler steps, e.g. len(trainloader)
        ratio: length of total steps divided by length of warm-up steps
    """
    def __init__(self,
                 optimizer: optim.Optimizer,
                 start: float = 1e-4,
                 end: float = 1e-6,
                 warmup_start: float = 1e-6,
                 total_steps: int = 24000,
                 ratio: float = 15):
        super(linear_scheduler_warmup, self).__init__()
        assert start > 0 and end > 0 and warmup_start > 0, "Learning rates must be strictly positive."
        assert ratio > 1, "Ratio must be larger than one."
        if optimizer.param_groups[0]["lr"] != warmup_start:
            warnings.warn(f"Initial learning rate is changed to {warmup_start}")
        self.optimizer = optimizer
        self.end = end
        self.start = start
        self.warmup_start = warmup_start
        # set steps
        self.total_steps = total_steps
        self.warmup_steps = int(self.total_steps // ratio)
        # calculate the slopes of the linear increase / decrease
        self.slope_train = (self.end - self.start) / (self.total_steps - self.warmup_steps)
        self.slope_warmup = (self.start - self.warmup_start) / self.warmup_steps
        # set the first learning rate
        self.optimizer.param_groups[0]["lr"] = self.warmup_start
        self.calls = 0

    def step(self) -> None:
        r"""Adjusts the learning rate of the optimizer by one step."""
        if self.calls < self.warmup_steps:
            self.optimizer.param_groups[0]["lr"] += self.slope_warmup
        else:
            self.optimizer.param_groups[0]["lr"] += self.slope_train
        self.calls += 1
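

# Minimal usage sketch (not part of the original file): the model, optimizer settings,
# and training loop below are hypothetical placeholders. It illustrates that the
# scheduler is stepped once per training iteration, so total_steps should match the
# number of optimizer updates, and that the optimizer's initial lr should equal
# warmup_start to avoid the warning above.
if __name__ == "__main__":
    import torch

    model = nn.Linear(10, 2)                      # hypothetical model
    optimizer = optim.SGD(model.parameters(), lr=1e-6)
    scheduler = linear_scheduler_warmup(optimizer,
                                        start=1e-4,
                                        end=1e-6,
                                        warmup_start=1e-6,
                                        total_steps=1000,
                                        ratio=15)
    for step in range(1000):                      # stands in for iterating over a trainloader
        x = torch.randn(4, 10)
        loss = model(x).pow(2).mean()             # dummy loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()                          # advance the learning rate by one step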