-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathtest_optimizer.py
69 lines (53 loc) · 2.11 KB
/
test_optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#!/usr/bin/env python
"""Tests for Constrained Optimizer class. This test already verifies that the
code behaves as expected for an unconstrained setting."""
import pytest
import testing_utils
import torch
# Import basic closure example from helpers
import toy_2d_problem
import cooper
@pytest.mark.parametrize("aim_device", ["cpu", "cuda"])
@pytest.mark.parametrize("use_ineq", [True, False])
def test_toy_problem(aim_device, use_ineq):
    """
    Solve a small bi-variate quadratic program and check convergence.

        min  x**2 + 2*y**2
        s.t. x + y >= 1
             x**2 + y <= 1

    With the inequality constraints active, the solution verified via
    WolframAlpha is (x=2/3, y=1/3); see https://tinyurl.com/ye8dw6t3.
    Without constraints, the quadratic's minimum sits at the origin.
    """
    # Skip cleanly when the requested device (e.g. CUDA) is unavailable.
    device, skip = testing_utils.get_device_skip(aim_device, torch.cuda.is_available())
    if skip.do_skip:
        pytest.skip(skip.skip_reason)

    # Start away from the optimum so the optimizer has real work to do.
    point = torch.nn.Parameter(torch.tensor([0.0, -1.0], device=device))
    primal_optimizer = torch.optim.SGD([point], lr=1e-2, momentum=0.3)

    # A dual optimizer only exists in the constrained formulation.
    dual_optimizer = (
        cooper.optim.partial_optimizer(torch.optim.SGD, lr=1e-2) if use_ineq else None
    )

    problem = toy_2d_problem.Toy2dCMP(use_ineq=use_ineq)
    lagrangian_formulation = cooper.LagrangianFormulation(problem)
    constrained_optimizer = cooper.ConstrainedOptimizer(
        formulation=lagrangian_formulation,
        primal_optimizer=primal_optimizer,
        dual_optimizer=dual_optimizer,
        dual_restarts=True,
    )

    # Run enough primal-dual iterations for SGD to converge.
    for _ in range(1500):
        constrained_optimizer.zero_grad()
        lagrangian_value = lagrangian_formulation.composite_objective(
            problem.closure, point
        )
        lagrangian_formulation.custom_backward(lagrangian_value)
        constrained_optimizer.step()

    if device == "cuda":
        # All CMP state tensors produced during training must live on the GPU.
        state = problem.state
        assert state.loss.is_cuda
        for defect in (state.eq_defect, state.ineq_defect):
            assert defect is None or defect.is_cuda

    if use_ineq:
        # Constrained optimum verified externally: (2/3, 1/3).
        assert torch.allclose(point[0], torch.tensor(2.0 / 3.0))
        assert torch.allclose(point[1], torch.tensor(1.0 / 3.0))
    else:
        # The unconstrained quadratic form has its minimum at the origin.
        assert torch.allclose(point, torch.tensor(0.0))