-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Correct installation path. Build and test badge. (#14)
- Fixes installation path in README (#12) - Adds build and test workflow plus README badge - Fixes linter issues
- Loading branch information
1 parent
562f64b
commit f691bf7
Showing
12 changed files
with
96 additions
and
95 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
import functools | ||
import math | ||
from collections.abc import Iterable | ||
from typing import Tuple, Type | ||
from typing import Callable, List, Tuple, Type, no_type_check | ||
|
||
import torch | ||
|
||
|
@@ -13,6 +13,7 @@ | |
RMSprop = torch.optim.RMSprop | ||
|
||
|
||
@no_type_check | ||
def partial(optim_cls: Type[torch.optim.Optimizer], **optim_kwargs): | ||
""" | ||
Partially instantiates an optimizer class. This approach is preferred over | ||
|
@@ -60,8 +61,6 @@ class PartialOptimizer(optim_cls): | |
|
||
# written by Hugo Berard ([email protected]) while at Facebook. | ||
|
||
required = object() | ||
|
||
|
||
class ExtragradientOptimizer(torch.optim.Optimizer): | ||
"""Base class for optimizers with extrapolation step. | ||
|
@@ -75,7 +74,7 @@ class ExtragradientOptimizer(torch.optim.Optimizer): | |
|
||
def __init__(self, params: Iterable, defaults: dict): | ||
super(ExtragradientOptimizer, self).__init__(params, defaults) | ||
self.params_copy = [] | ||
self.params_copy: List[torch.nn.Parameter] = [] | ||
|
||
def update(self, p, group): | ||
raise NotImplementedError | ||
|
@@ -101,7 +100,7 @@ def extrapolation(self): | |
# Update the current parameters | ||
p.data.add_(u) | ||
|
||
def step(self, closure: callable = None): | ||
def step(self, closure: Callable = None): | ||
"""Performs a single optimization step. | ||
Args: | ||
|
@@ -148,7 +147,8 @@ class ExtraSGD(ExtragradientOptimizer): | |
.. note:: | ||
The implementation of SGD with Momentum/Nesterov subtly differs from | ||
:cite:t:`sutskever2013initialization`. and implementations in some other frameworks. | ||
:cite:t:`sutskever2013initialization`. and implementations in some other | ||
frameworks. | ||
Considering the specific case of Momentum, the update can be written as | ||
|
@@ -172,13 +172,13 @@ class ExtraSGD(ExtragradientOptimizer): | |
def __init__( | ||
self, | ||
params: Iterable, | ||
lr: float = required, | ||
lr: float, | ||
momentum: float = 0, | ||
dampening: float = 0, | ||
weight_decay: float = 0, | ||
nesterov: bool = False, | ||
): | ||
if lr is not required and lr < 0.0: | ||
if lr is None or lr < 0.0: | ||
raise ValueError("Invalid learning rate: {}".format(lr)) | ||
if momentum < 0.0: | ||
raise ValueError("Invalid momentum value: {}".format(momentum)) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.