test_proxy_constraints.py
#!/usr/bin/env python

"""Tests for Constrained Optimizer class."""

import functools

import pytest
import testing_utils
import torch

# Import basic closure example from helpers
import toy_2d_problem

import cooper


@pytest.mark.parametrize("aim_device", ["cpu", "cuda"])
def test_toy_problem(aim_device):
    """
    Simple test on a bivariate quadratic programming problem

        min x**2 + 2*y**2
        s.t.
            x + y >= 1
            x**2 + y <= 1

    Verified solution from WolframAlpha: (x, y) = (2/3, 1/3).
    Link to WolframAlpha query: https://tinyurl.com/ye8dw6t3
    """
    device, skip = testing_utils.get_device_skip(aim_device, torch.cuda.is_available())
    if skip.do_skip:
        pytest.skip(skip.skip_reason)

    params = torch.nn.Parameter(torch.tensor([0.0, -1.0], device=device))
    primal_optimizer = torch.optim.SGD([params], lr=5e-2, momentum=0.0)
    dual_optimizer = cooper.optim.partial_optimizer(torch.optim.SGD, lr=1e-2)
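    # The dual optimizer is only partially instantiated here: the Lagrange
    # multipliers it will act on do not exist until the first closure
    # evaluation reveals how many constraints the problem has, so cooper
    # defers building the optimizer until then.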
    cmp = toy_2d_problem.Toy2dCMP(use_ineq=True, use_proxy_ineq=True)
    formulation = cooper.LagrangianFormulation(cmp)

    coop = cooper.ConstrainedOptimizer(
        formulation=formulation,
        primal_optimizer=primal_optimizer,
        dual_optimizer=dual_optimizer,
        dual_restarts=False,
        alternating=False,
    )
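    # dual_restarts=False keeps multiplier values across steps instead of
    # resetting them to zero when their constraint becomes feasible, and
    # alternating=False updates the primal and dual variables simultaneously
    # from the same backward pass.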
    # Helper function to instantiate tensors on the correct device
    mktensor = functools.partial(torch.tensor, device=device)

    # ----------------------- First iteration -----------------------
    coop.zero_grad()
    lagrangian = formulation.composite_objective(cmp.closure, params)

    # Check loss, proxy and non-proxy defects after the forward pass
    assert torch.allclose(lagrangian, mktensor(2.0))
    assert torch.allclose(cmp.state.loss, mktensor(2.0))
    assert torch.allclose(cmp.state.ineq_defect, mktensor([2.0, -2.0]))
    assert torch.allclose(cmp.state.proxy_ineq_defect, mktensor([2.0, -1.9]))
    assert cmp.state.eq_defect is None
    assert cmp.state.proxy_eq_defect is None
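    # Sanity check of the values above at (x, y) = (0, -1):
    #   loss = x**2 + 2*y**2 = 0 + 2 = 2.0
    #   defect of x + y >= 1:    1 - (x + y) = 2.0
    #   defect of x**2 + y <= 1: x**2 + y - 1 = -2.0
    # The proxy defect of the second constraint is -1.9 instead of -2.0
    # because Toy2dCMP evaluates a slightly different surrogate for the primal
    # update (consistent with x**2 + 0.9*y - 1, which gives -1.9 here).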
    # Multiplier initialization
    assert torch.allclose(formulation.state()[0], mktensor([0.0, 0.0]))
    assert formulation.state()[1] is None

    # Check primal and dual gradients after backward. The dual gradient must
    # match ineq_defect
    formulation.custom_backward(lagrangian)
    assert torch.allclose(params.grad, mktensor([0.0, -4.0]))
    assert torch.allclose(formulation.state()[0].grad, cmp.state.ineq_defect)
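    # With both multipliers still at zero, the Lagrangian reduces to the loss,
    # so params.grad is just the loss gradient (2*x, 4*y) = (0.0, -4.0).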
    # Check updated primal and dual variable values
    coop.step()
    assert torch.allclose(params, mktensor([0.0, -0.8]))
    assert torch.allclose(formulation.state()[0], mktensor([0.02, 0.0]))
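    # Primal SGD step with lr=5e-2: y = -1.0 - 0.05 * (-4.0) = -0.8.
    # Dual ascent step with lr=1e-2 on the defects [2.0, -2.0], projected onto
    # the non-negative orthant: [0.0 + 0.02, max(0.0, 0.0 - 0.02)] = [0.02, 0.0].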
    # ----------------------- Second iteration -----------------------
    coop.zero_grad()
    lagrangian = formulation.composite_objective(cmp.closure, params)

    # Check loss, proxy and non-proxy defects after the forward pass
    assert torch.allclose(lagrangian, mktensor(1.316))
    assert torch.allclose(cmp.state.loss, mktensor(1.28))
    assert torch.allclose(cmp.state.ineq_defect, mktensor([1.8, -1.8]))
    assert torch.allclose(cmp.state.proxy_ineq_defect, mktensor([1.8, -1.72]))
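    # At (x, y) = (0, -0.8): loss = 2 * 0.8**2 = 1.28, and the Lagrangian adds
    # the multiplier-weighted defect of the first constraint:
    # 1.28 + 0.02 * 1.8 = 1.316.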
    # Check primal and dual gradients after backward. The dual gradient must
    # match ineq_defect
    formulation.custom_backward(lagrangian)
    assert torch.allclose(params.grad, mktensor([-0.018, -3.22]))
    assert torch.allclose(formulation.state()[0].grad, cmp.state.ineq_defect)
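    # The primal gradient now picks up the proxy-constraint term. Assuming the
    # first proxy defect is 1 - 0.9*x - y (consistent with the values asserted
    # above):
    #   d/dx: 2*x + 0.02 * (-0.9) = -0.018
    #   d/dy: 4*y + 0.02 * (-1.0) = -3.2 - 0.02 = -3.22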
    # Check updated primal and dual variable values
    coop.step()
    assert torch.allclose(params, mktensor([9e-4, -0.639]))
    assert torch.allclose(formulation.state()[0], mktensor([0.038, 0.0]))
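    # Primal: x = 0.0 - 0.05 * (-0.018) = 9e-4 and y = -0.8 - 0.05 * (-3.22) = -0.639.
    # Dual: lambda_1 = 0.02 + 0.01 * 1.8 = 0.038, while lambda_2 stays clamped at 0.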
    if device == "cuda":
        assert cmp.state.loss.is_cuda
        assert cmp.state.ineq_defect.is_cuda