test_extrapolation.py
#!/usr/bin/env python

"""Tests for Extrapolation optimizers."""

import functools

import pytest
import testing_utils
import torch

# Import basic closure example from helpers
import toy_2d_problem

import cooper


@pytest.mark.parametrize("aim_device", ["cpu", "cuda"])
@pytest.mark.parametrize("primal_optimizer_str", ["ExtraSGD", "ExtraAdam"])
def test_extrapolation(aim_device, primal_optimizer_str):
    """
    Simple test on a bi-variate quadratic programming problem
        min x**2 + 2*y**2
        st.
            x + y >= 1
            x**2 + y <= 1

    Verified solution from WolframAlpha (x=2/3, y=1/3)
    Link to WolframAlpha query: https://tinyurl.com/ye8dw6t3
    """

    device, skip = testing_utils.get_device_skip(aim_device, torch.cuda.is_available())

    if skip.do_skip:
        pytest.skip(skip.skip_reason)

    params = torch.nn.Parameter(torch.tensor([0.0, -1.0], device=device))

    try:
        optimizer_class = getattr(cooper.optim, primal_optimizer_str)
    except AttributeError:
        optimizer_class = getattr(torch.optim, primal_optimizer_str)

    primal_optimizer = optimizer_class([params], lr=1e-2)
    dual_optimizer = cooper.optim.partial_optimizer(cooper.optim.ExtraSGD, lr=1e-2)

    cmp = toy_2d_problem.Toy2dCMP(use_ineq=True)
    formulation = cooper.LagrangianFormulation(cmp)

    coop = cooper.ConstrainedOptimizer(
        formulation=formulation,
        primal_optimizer=primal_optimizer,
        dual_optimizer=dual_optimizer,
        dual_restarts=False,
    )

    for step_id in range(2000):
        coop.zero_grad()
        lagrangian = formulation.composite_objective(cmp.closure, params)
        formulation.custom_backward(lagrangian)
        coop.step(cmp.closure, params)

    if device == "cuda":
        assert cmp.state.loss.is_cuda
        assert cmp.state.eq_defect is None or cmp.state.eq_defect.is_cuda
        assert cmp.state.ineq_defect is None or cmp.state.ineq_defect.is_cuda

    # TODO: Why do we need such a relaxed tolerance for this test to pass?
    if primal_optimizer_str == "ExtraSGD":
        atol = 1e-8
    else:
        atol = 1e-3

    assert torch.allclose(params[0], torch.tensor(2.0 / 3.0), atol=atol)
    assert torch.allclose(params[1], torch.tensor(1.0 / 3.0), atol=atol)
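

# Illustrative sketch (not part of the original suite): the solution quoted in
# the docstring above can also be recovered analytically from the KKT
# conditions. At the optimum the constraint x + y >= 1 is active while
# x**2 + y <= 1 is slack, so stationarity of x**2 + 2*y**2 - lam * (x + y - 1)
# gives (2x, 4y) = lam * (1, 1), hence x = 2y; combined with x + y = 1 this
# yields (x, y) = (2/3, 1/3) with lam = 4/3. The helper name below is
# hypothetical and relies only on torch, which is already imported.
def _check_kkt_for_toy_problem():
    x, y, lam = 2.0 / 3.0, 1.0 / 3.0, 4.0 / 3.0
    grad_f = torch.tensor([2 * x, 4 * y])  # gradient of the objective
    grad_g = torch.tensor([1.0, 1.0])  # gradient of the active constraint x + y - 1
    # Stationarity: grad_f == lam * grad_g
    assert torch.allclose(grad_f, lam * grad_g)
    # Primal feasibility: the first constraint is (numerically) tight, the
    # second is strictly satisfied, so its multiplier is zero.
    assert abs(x + y - 1.0) < 1e-12
    assert x**2 + y < 1.0

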
@pytest.mark.parametrize("aim_device", ["cpu", "cuda"])
@pytest.mark.parametrize("primal_optimizer", ["ExtraSGD"])
def test_manual_extrapolation(aim_device, primal_optimizer):
    """
    Simple test on a bi-variate quadratic programming problem
        min x**2 + 2*y**2
        st.
            x + y >= 1
            x**2 + y <= 1

    Verified solution from WolframAlpha (x=2/3, y=1/3)
    Link to WolframAlpha query: https://tinyurl.com/ye8dw6t3
    """

    device, skip = testing_utils.get_device_skip(aim_device, torch.cuda.is_available())

    if skip.do_skip:
        pytest.skip(skip.skip_reason)

    params = torch.nn.Parameter(torch.tensor([0.0, -1.0], device=device))

    try:
        optimizer_class = getattr(cooper.optim, primal_optimizer)
    except AttributeError:
        optimizer_class = getattr(torch.optim, primal_optimizer)

    primal_optimizer = optimizer_class([params], lr=1e-2)
    dual_optimizer = cooper.optim.partial_optimizer(cooper.optim.ExtraSGD, lr=1e-2)

    cmp = toy_2d_problem.Toy2dCMP(use_ineq=True)
    formulation = cooper.LagrangianFormulation(cmp)

    coop = cooper.ConstrainedOptimizer(
        formulation=formulation,
        primal_optimizer=primal_optimizer,
        dual_optimizer=dual_optimizer,
        dual_restarts=False,
    )

    # Helper function to instantiate tensors on the correct device
    mktensor = functools.partial(torch.tensor, device=device)

    coop.zero_grad()
    lagrangian = formulation.composite_objective(cmp.closure, params)

    # Check loss, proxy and non-proxy defects after forward pass
    assert torch.allclose(lagrangian, mktensor(2.0))
    assert torch.allclose(cmp.state.loss, mktensor(2.0))
    assert torch.allclose(cmp.state.ineq_defect, mktensor([2.0, -2.0]))
    assert cmp.state.eq_defect is None

    # Multiplier initialization
    assert torch.allclose(formulation.state()[0], mktensor([0.0, 0.0]))
    assert formulation.state()[1] is None

    # Check primal and dual gradients after backward. The dual gradient must
    # match ineq_defect.
    formulation.custom_backward(lagrangian)
    assert torch.allclose(params.grad, mktensor([0.0, -4.0]))
    assert torch.allclose(formulation.state()[0].grad, cmp.state.ineq_defect)

    # Check updated primal and dual variable values
    coop.step(cmp.closure, params)
    assert torch.allclose(params, mktensor([2.0e-4, -0.9614]))
    assert torch.allclose(formulation.state()[0], mktensor([0.0196, 0.0]))
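

# Worked arithmetic behind the manual checks above (an illustrative sketch, not
# part of the original file). At the initial point (x, y) = (0, -1) with the
# multipliers initialized at zero, and with the defect convention consistent
# with the asserted values:
#
#     loss         = x**2 + 2 * y**2              = 2.0
#     ineq_defect  = [1 - (x + y), x**2 + y - 1]  = [2.0, -2.0]
#     lagrangian   = loss + multipliers . defect  = 2.0
#     params.grad  = [2 * x, 4 * y]               = [0.0, -4.0]
#     dual grad    = ineq_defect                  = [2.0, -2.0]
#
# The post-step values follow from the extrapolation ("look-ahead") update with
# lr = 1e-2: the dual gradient first moves the multipliers to [0.02, -0.02],
# which (assuming projection onto the non-negative orthant) become [0.02, 0.0];
# re-evaluating the Lagrangian gradients at the extrapolated point (0.0, -0.96)
# gives a primal gradient of [-0.02, -3.86] and a defect of [1.96, -1.96], so
# the final update lands at params = [2.0e-4, -0.9614] and multipliers =
# [0.0196, 0.0], matching the assertions above.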