diff --git a/cooper/formulation/__init__.py b/cooper/formulation/__init__.py
index 494080e1..a06476fe 100644
--- a/cooper/formulation/__init__.py
+++ b/cooper/formulation/__init__.py
@@ -1,3 +1,4 @@
 from .augmented_lagrangian import AugmentedLagrangianFormulation
 from .formulation import Formulation, UnconstrainedFormulation
 from .lagrangian import LagrangianFormulation, ProxyLagrangianFormulation
+from .lagrangian_model import LagrangianModelFormulation
diff --git a/cooper/formulation/augmented_lagrangian.py b/cooper/formulation/augmented_lagrangian.py
index 28aabe4b..63ff1853 100644
--- a/cooper/formulation/augmented_lagrangian.py
+++ b/cooper/formulation/augmented_lagrangian.py
@@ -90,7 +90,7 @@ def weighted_violation(
             # to update the value of the multipliers by lazily filling the
             # multiplier gradients in `backward`

-            # TODO (JGP): Verify that call to backward is general enough for
+            # TODO (gallego-posada): Verify that call to backward is general enough for
             # Lagrange Multiplier models
             violation_for_update = torch.sum(multipliers * defect.detach())
             self.update_accumulated_violation(update=violation_for_update)
diff --git a/cooper/formulation/formulation.py b/cooper/formulation/formulation.py
index adafcb24..df786cf5 100644
--- a/cooper/formulation/formulation.py
+++ b/cooper/formulation/formulation.py
@@ -1,10 +1,13 @@
 import abc
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional, Union

 import torch

 from cooper.problem import CMPState, ConstrainedMinimizationProblem

+# from .lagrangian_model import CMPModelState
+
+
 # Formulation, and some other classes below, are heavily inspired by the design
 # of the TensorFlow Constrained Optimization (TFCO) library :
 # https://github.com/google-research/tensorflow_constrained_optimization
@@ -34,6 +37,11 @@ def state(self):
         """Returns the internal state of formulation (e.g. multipliers)."""
         pass

+    @abc.abstractmethod
+    def flip_dual_gradients(self):
+        """Flips the sign of the dual gradients."""
+        pass
+
     @property
     @abc.abstractmethod
     def is_state_created(self):
@@ -59,7 +67,8 @@ def backward(self, *args, **kwargs):
         formulation."""
         pass

-    def write_cmp_state(self, cmp_state: CMPState):
+    # TODO(daoterog): fix the circular import so the commented type hint can be used
+    def write_cmp_state(self, cmp_state: CMPState):  # Union[CMPState, CMPModelState]
         """Provided that the formulation is linked to a
         `ConstrainedMinimizationProblem`, writes a CMPState to the CMP."""

@@ -119,6 +128,11 @@ def load_state_dict(self, state_dict: dict):
         """
         pass

+    def flip_dual_gradients(self):
+        """Flips the sign of the dual gradients. This is a no-op for
+        unconstrained formulations."""
+        pass
+
     def compute_lagrangian(
         self,
         closure: Callable[..., CMPState],
diff --git a/cooper/formulation/lagrangian.py b/cooper/formulation/lagrangian.py
index 9266ddbc..7e3147df 100644
--- a/cooper/formulation/lagrangian.py
+++ b/cooper/formulation/lagrangian.py
@@ -42,7 +42,7 @@ def __init__(
         # Store user-provided initializations for dual variables
         self.ineq_init = ineq_init
         self.eq_init = eq_init
-
+        # TODO(gallego-posada): document the meaning of this attribute
         self.accumulated_violation_dot_prod: torch.Tensor = None

     @property
@@ -193,11 +193,11 @@ def state_dict(self) -> Dict[str, Any]:
         """
         Generates the state dictionary for a Lagrangian formulation.
""" - + # TODO(gallego-posada): fix in next PR state_dict = { "ineq_multipliers": self.ineq_multipliers, "eq_multipliers": self.eq_multipliers, - "accumulated_violation_dot_prod": self.accumulated_violation_dot_prod, + # "accumulated_violation_dot_prod": self.accumulated_violation_dot_prod, } return state_dict @@ -220,6 +220,14 @@ def load_state_dict(self, state_dict: Dict[str, Any]): ), "LagrangianFormulation received unknown key: {}".format(key) setattr(self, key, val) + def flip_dual_gradients(self): + """ + Flip the sign of the gradients of the dual variables. + """ + for multiplier in self.state(): + if multiplier is not None: + multiplier.grad.mul_(-1.0) + class LagrangianFormulation(BaseLagrangianFormulation): """ @@ -353,7 +361,7 @@ def weighted_violation( # on `accumulated_violation_dot_prod`. This enables easy # extensibility to multiplier classes beyond DenseMultiplier. - # TODO (JGP): Verify that call to backward is general enough for + # TODO (gallego-posada): Verify that call to backward is general enough for # Lagrange Multiplier models violation_for_update = torch.sum(multipliers * defect.detach()) self.update_accumulated_violation(update=violation_for_update) diff --git a/cooper/formulation/lagrangian_model.py b/cooper/formulation/lagrangian_model.py new file mode 100644 index 00000000..38e7687e --- /dev/null +++ b/cooper/formulation/lagrangian_model.py @@ -0,0 +1,383 @@ +from dataclasses import dataclass +from typing import ( + Callable, + List, + Optional, + Tuple, + Union, + no_type_check, + Dict, + Any, + Iterator, +) + +import torch + +from .lagrangian import BaseLagrangianFormulation +from cooper.multipliers import MultiplierModel + + +@dataclass +class CMPModelState: + """ + Represents the "state" of a Constrained Minimization Problem in terms of + the value of its loss and constraint violations/defects. The main difference between + this object and `CMPState` is that it also stores the features for the constraints + that are going to be passed to the multiplier models to predict the Lagrange + multipliers. + + Args: + loss: Value of the loss or main objective to be minimized :math:`f(x)` + ineq_defect: Violation of the inequality constraints :math:`g(x)` + eq_defect: Violation of the equality constraints :math:`h(x)` + proxy_ineq_defect: Differentiable surrogate for the inequality + constraints as proposed by :cite:t:`cotter2019JMLR`. + proxy_eq_defect: Differentiable surrogate for the equality constraints + as proposed by :cite:t:`cotter2019JMLR`. + ineq_constraint_features: Features for the inequality constraints that are + going to be passed to the inequality multiplier model to predict the Lagrange + multipliers. + eq_constraint_features: Features for the equality constraints that are going to + be passed to the equality multiplier model to predict the Lagrange multipliers. 
+        misc: Optional additional information to be stored along with the state
+            of the CMP
+    """
+
+    loss: Optional[torch.Tensor] = None
+    ineq_defect: Optional[torch.Tensor] = None
+    eq_defect: Optional[torch.Tensor] = None
+    proxy_ineq_defect: Optional[torch.Tensor] = None
+    proxy_eq_defect: Optional[torch.Tensor] = None
+    ineq_constraint_features: Optional[torch.Tensor] = None
+    eq_constraint_features: Optional[torch.Tensor] = None
+    misc: Optional[dict] = None
+
+    def as_tuple(self) -> tuple:
+        return (
+            self.loss,
+            self.ineq_defect,
+            self.eq_defect,
+            self.proxy_ineq_defect,
+            self.proxy_eq_defect,
+            self.ineq_constraint_features,
+            self.eq_constraint_features,
+            self.misc,
+        )
+
+
+class LagrangianModelFormulation(BaseLagrangianFormulation):
+    """
+    Computes the Lagrangian based on the predictions of a `MultiplierModel`. This
+    formulation is useful when the Lagrange multipliers are not kept explicitly, but
+    are instead predicted by a model, e.g., a neural network. This formulation is
+    meant to be used along with the :py:class:`~cooper.multipliers.MultiplierModel`.
+
+    Attributes:
+        ineq_multiplier_model: The model used to predict the Lagrange multipliers
+            associated with the inequality constraints. If ``None``, the
+            :py:meth:`~cooper.formulation.lagrangian_model.LagrangianModelFormulation.state`
+            method will not return the Lagrange multipliers associated with the
+            inequality constraints.
+        eq_multiplier_model: The model used to predict the Lagrange multipliers
+            associated with the equality constraints. If ``None``, the
+            :py:meth:`~cooper.formulation.lagrangian_model.LagrangianModelFormulation.state`
+            method will not return the Lagrange multipliers associated with the
+            equality constraints.
+        **kwargs: Additional keyword arguments to be passed to the
+            ``BaseLagrangianFormulation`` constructor.
+    """
+
+    def __init__(
+        self,
+        *args,
+        ineq_multiplier_model: Optional[MultiplierModel] = None,
+        eq_multiplier_model: Optional[MultiplierModel] = None,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+        self.ineq_multiplier_model = ineq_multiplier_model
+        self.eq_multiplier_model = eq_multiplier_model
+
+        self.base_sanity_checks()
+
+    def base_sanity_checks(self):
+        """
+        Perform sanity checks on the initialization of ``LagrangianModelFormulation``.
+        """
+
+        if self.ineq_multiplier_model is None and self.eq_multiplier_model is None:
+            # This formulation cannot perform any prediction if no multiplier model is
+            # provided.
+            raise ValueError("At least one multiplier model must be provided.")
+
+        if self.ineq_multiplier_model is not None and not isinstance(
+            self.ineq_multiplier_model, MultiplierModel
+        ):
+            raise ValueError("The `ineq_multiplier_model` must be a `MultiplierModel`.")
+
+        if self.eq_multiplier_model is not None and not isinstance(
+            self.eq_multiplier_model, MultiplierModel
+        ):
+            raise ValueError("The `eq_multiplier_model` must be a `MultiplierModel`.")
+
+    @property
+    def dual_parameters(self) -> List[torch.Tensor]:
+        """Returns a list gathering all dual parameters."""
+        all_dual_params = []
+
+        for mult in [self.ineq_multiplier_model, self.eq_multiplier_model]:
+            if mult is not None:
+                all_dual_params.extend(list(mult.parameters()))
+
+        return all_dual_params
+
+    def state(
+        self,
+    ) -> Tuple[
+        Union[None, Iterator[torch.nn.Parameter]],
+        Union[None, Iterator[torch.nn.Parameter]],
+    ]:
+        """
+        Collects all dual variables and returns a tuple containing their
+        :py:class:`Iterator[torch.nn.Parameter]` values. Note that the *values*
+        correspond to the parameters of the `MultiplierModel`.
+ """ + + if self.ineq_multiplier_model is None: + ineq_state = None + else: + ineq_state = self.ineq_multiplier_model.parameters() + + if self.eq_multiplier_model is None: + eq_state = None + else: + eq_state = self.eq_multiplier_model.parameters() + + return ineq_state, eq_state + + def create_state(self): + """This method is not implemented for this formulation. It originally + instantiates the dual variables, but in this formulation this is done since the + instantiation of the object, since it is necessary to provide a `MultiplerModel` + for each of the contraint types.""" + pass + + @property + def is_state_created(self): + """ + Returns ``True`` if any Multiplier Model have been initialized. + """ + return ( + self.ineq_multiplier_model is not None + or self.eq_multiplier_model is not None + ) + + def state_dict(self) -> Dict[str, Any]: + """ + Generates the state dictionary for a Lagrangian model formulation. + """ + + state_dict = { + "ineq_multiplier_model": self.ineq_multiplier_model.state_dict(), + "eq_multiplier_model": self.eq_multiplier_model.state_dict(), + } + return state_dict + + def load_state_dict(self, state_dict: Dict[str, Any]): + """ + Loads the state dictionary of a Lagrangian formulation. + + Args: + state_dict: state dictionary to be loaded. + """ + + known_attrs = ["ineq_multiplier_model", "eq_multiplier_model"] + + for key, val in state_dict.items(): + + if key not in known_attrs: + raise ValueError( + f"LagrangianModelFormulation received unknown key: {key}. Valid keys are {known_attrs}" + ) + + if key in ["ineq_multiplier_model", "eq_multiplier_model"]: + multiplier_model = getattr(self, key) + multiplier_model.load_state_dict(val) + + def flip_dual_gradients(self) -> None: + """ + Flips the sign of the gradients for the dual variables. This is useful + when using the dual formulation in conjunction with the alternating + update scheme. + """ + + for multiplier_model_state in self.state(): + if multiplier_model_state is not None: + for param in multiplier_model_state: + if param.grad is not None: + param.grad.mul_(-1.0) + + @no_type_check + def compute_lagrangian( + self, + closure: Callable[..., CMPModelState] = None, + *closure_args, + pre_computed_state: Optional[CMPModelState] = None, + write_state: bool = True, + **closure_kwargs, + ) -> torch.Tensor: + """ + Computes the Lagrangian of the problem, given the current state of the + optimization problem. This method is used to compute the loss function + for the optimization problem. + + Args: + closure: A function that returns a :py:class:`CMPModelState` object. This + function is used to compute the loss function and the constraint + violations. If ``None``, the ``pre_computed_state`` argument must be + provided. + *closure_args: Positional arguments to be passed to the ``closure`` + function. + pre_computed_state: A :py:class:`CMPModelState` object containing the + pre-computed loss function and constraint violations. If ``None``, + the ``closure`` argument must be provided. + write_state: If ``True``, the state of the optimization problem is + written to the ``cmp_model_state`` attribute of the :py:class:`CMPModelState` + object. + **closure_kwargs: Keyword arguments to be passed to the ``closure`` + function. 
+ """ + + assert ( + closure is not None or pre_computed_state is not None + ), "At least one of closure or pre_computed_state must be provided" + + if pre_computed_state is not None: + cmp_model_state = pre_computed_state + else: + cmp_model_state = closure(*closure_args, **closure_kwargs) + + if write_state: + self.write_cmp_state(cmp_model_state) + + # Extract values from ProblemState object + loss = cmp_model_state.loss + + # Purge previously accumulated constraint violations + self.update_accumulated_violation(update=None) + + # Compute contribution of the sampled constraint violations, weighted by the + # current multiplier values predicted by the multuplier model. + ineq_viol = self.weighted_violation(cmp_model_state, "ineq") + eq_viol = self.weighted_violation(cmp_model_state, "eq") + + # Lagrangian = loss + \sum_i multiplier_i * defect_i + lagrangian = loss + ineq_viol + eq_viol + + return lagrangian + + def weighted_violation( + self, cmp_model_state: CMPModelState, constraint_type: str + ) -> torch.Tensor: + """ + Computes the dot product between the current multipliers and the + constraint violations of type ``constraint_type``. The multiplier correspond to + the output of a `MultiplierModel` provided when the formulation was initialized. + The model is trained on `constraint_features` provided in the CMPModelState. + If proxy-constraints are provided in the :py:class:`.CMPModelState`, the non- + proxy (usually non-differentiable) constraints are used for computing the dot + product, while the "proxy-constraint" dot products are accumulated under + ``self.accumulated_violation_dot_prod``. + + Args: + cmp_model_state: current ``CMPModelState`` + constraint_type: type of constrained to be used, e.g. "eq" or "ineq". + """ + + defect = getattr(cmp_model_state, constraint_type + "_defect") + has_defect = defect is not None + + proxy_defect = getattr(cmp_model_state, "proxy_" + constraint_type + "_defect") + has_proxy_defect = proxy_defect is not None + + if not has_proxy_defect: + # If not given proxy constraints, then the regular defects are + # used for computing gradients and evaluating the multipliers + proxy_defect = defect + + if not has_defect: + # We should always have at least the "regular" defects, if not, then + # the problem instance does not have `constraint_type` constraints + violation = torch.tensor([0.0], device=cmp_model_state.loss.device) + else: + multiplier_model = getattr(self, constraint_type + "_multiplier_model") + + # Get multipliers by performing a prediction over the features of the + # sampled constraints + constraint_features = getattr( + cmp_model_state, constraint_type + "_constraint_features" + ) + + multipliers = multiplier_model.forward(constraint_features) + + # The violations are computed via inner product between the multipliers + # and the defects, they should have the same shape. If given proxy-defects + # then their shape has to be checked as well. + assert defect.shape == proxy_defect.shape == multipliers.shape + + # Store the multiplier values + setattr(self, constraint_type + "_multipliers", multipliers) + + # We compute (primal) gradients of this object with the sampled + # constraints + violation = torch.sum(multipliers.detach() * proxy_defect) + + # This is the violation of the "actual/hard" constraint. We use this + # to update the multipliers. + # The gradients for the dual variables are computed via a backward + # on `accumulated_violation_dot_prod`. This enables easy + # extensibility to multiplier classes beyond DenseMultiplier. 
+
+            # TODO (gallego-posada): Verify that call to backward is general enough for
+            # Lagrange Multiplier models
+            violation_for_update = torch.sum(multipliers * defect.detach())
+            self.update_accumulated_violation(update=violation_for_update)
+
+        return violation
+
+    @no_type_check
+    def backward(
+        self,
+        lagrangian: torch.Tensor,
+        ignore_primal: bool = False,
+        ignore_dual: bool = False,
+    ):
+        """
+        Performs the actual backward computation which populates the gradients
+        for the primal and dual variables.
+
+        Args:
+            lagrangian: Value of the computed Lagrangian based on which the
+                gradients for the primal and dual variables are populated.
+            ignore_primal: If ``True``, only the gradients with respect to the
+                dual variables are populated (these correspond to the constraint
+                violations). This feature is mainly used in conjunction with
+                ``alternating`` updates, which require updating the multipliers
+                based on the constraint violations *after* having updated the
+                primal parameters. Defaults to False.
+            ignore_dual: If ``True``, the gradients with respect to the dual
+                variables are not populated.
+        """
+
+        if ignore_primal:
+            # Only compute gradients wrt Lagrange multipliers.
+            # No need to call backward on the Lagrangian, as the dual variables
+            # were detached when computing the `weighted_violation`s.
+            pass
+        else:
+            # Compute gradients wrt _primal_ parameters only.
+            # The gradient for the dual variables is computed based on the
+            # non-proxy violations below.
+            lagrangian.backward()
+
+        # Fill in the gradients for the dual variables based on the violation of
+        # the non-proxy constraints
+        if not ignore_dual:
+            dual_vars = self.dual_parameters
+            self.accumulated_violation_dot_prod.backward(inputs=dual_vars)
diff --git a/cooper/multipliers/__init__.py b/cooper/multipliers/__init__.py
index 5c3ef182..267360f2 100644
--- a/cooper/multipliers/__init__.py
+++ b/cooper/multipliers/__init__.py
@@ -1 +1,2 @@
+from .multiplier_model import MultiplierModel
 from .multipliers import BaseMultiplier, DenseMultiplier
diff --git a/cooper/multipliers/multiplier_model.py b/cooper/multipliers/multiplier_model.py
new file mode 100644
index 00000000..182cadb5
--- /dev/null
+++ b/cooper/multipliers/multiplier_model.py
@@ -0,0 +1,43 @@
+import abc
+from typing import Iterator
+
+import torch
+
+from .multipliers import BaseMultiplier
+
+
+class MultiplierModel(BaseMultiplier, metaclass=abc.ABCMeta):
+    """
+    A multiplier model. Holds a :py:class:`~torch.nn.Module` which predicts
+    the value of the Lagrange multipliers associated with the equality or
+    inequality constraints of a
+    :py:class:`~cooper.problem.ConstrainedMinimizationProblem`. This class is
+    meant to be subclassed by the user to implement their own multiplier model.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    @abc.abstractmethod
+    def forward(self, constraint_features: torch.Tensor):
+        """
+        Returns the *actual* value of the multipliers: the model maps the
+        "features" of each constraint to its corresponding multiplier.
+ """ + pass + + @property + def grad(self) -> Iterator[torch.Tensor]: + raise RuntimeError("""grad method does not exist for MultiplierModel.""") + + @property + def shape(self): + raise RuntimeError("""shape method does not exist for MultiplierModel.""") + + def project_(self): + raise RuntimeError("""project_ method does not exist for MultiplierModel.""") + + def restart_if_feasible_(self): + raise RuntimeError( + """restart_if_feasible_ method does not exist for MultiplierModel.""" + ) diff --git a/cooper/optim/constrained_optimizers/alternating_optimizer.py b/cooper/optim/constrained_optimizers/alternating_optimizer.py index 86cae4dd..2ab619f7 100644 --- a/cooper/optim/constrained_optimizers/alternating_optimizer.py +++ b/cooper/optim/constrained_optimizers/alternating_optimizer.py @@ -7,8 +7,13 @@ import torch -from cooper.formulation import AugmentedLagrangianFormulation, Formulation +from cooper.formulation import ( + AugmentedLagrangianFormulation, + Formulation, + LagrangianModelFormulation, +) from cooper.problem import CMPState +from cooper.formulation.lagrangian_model import CMPModelState from .constrained_optimizer import ConstrainedOptimizer @@ -40,9 +45,9 @@ def __init__( def step( self, - closure: Optional[Callable[..., CMPState]] = None, + closure: Optional[Callable[..., Union[CMPState, CMPModelState]]] = None, *closure_args, - defect_fn: Optional[Callable[..., CMPState]] = None, + defect_fn: Optional[Callable[..., Union[CMPState, CMPModelState]]] = None, **closure_kwargs, ): """ @@ -134,14 +139,14 @@ def dual_step(self): # Flip gradients for multipliers to perform ascent. # We only do the flipping *right before* applying the optimizer step to # avoid accidental double sign flips. - for multiplier in self.formulation.state(): - if multiplier is not None: - multiplier.grad.mul_(-1.0) + self.formulation.flip_dual_gradients() # Update multipliers based on current constraint violations (gradients) self.dual_optimizer.step() - if self.formulation.ineq_multipliers is not None: + if self.formulation.ineq_multipliers is not None and not isinstance( + self.formulation, LagrangianModelFormulation + ): if self.dual_restarts: # "Reset" value of inequality multipliers to zero as soon as # solution becomes feasible diff --git a/cooper/optim/constrained_optimizers/constrained_optimizer.py b/cooper/optim/constrained_optimizers/constrained_optimizer.py index bfe9948d..a137428c 100644 --- a/cooper/optim/constrained_optimizers/constrained_optimizer.py +++ b/cooper/optim/constrained_optimizers/constrained_optimizer.py @@ -65,6 +65,8 @@ def base_sanity_checks(self): Perform sanity checks on the initialization of ``ConstrainedOptimizer``. 
""" + # TODO(daoterog): ensure that the dual optimizer and the dual scheduler is initialized when LagrangeModelFormulation + if self.dual_optimizer is None: raise RuntimeError("No dual optimizer was provided.") diff --git a/cooper/optim/constrained_optimizers/extrapolation_optimizer.py b/cooper/optim/constrained_optimizers/extrapolation_optimizer.py index c56d4883..dc3b14ad 100644 --- a/cooper/optim/constrained_optimizers/extrapolation_optimizer.py +++ b/cooper/optim/constrained_optimizers/extrapolation_optimizer.py @@ -7,9 +7,14 @@ import torch -from cooper.formulation import AugmentedLagrangianFormulation, Formulation +from cooper.formulation import ( + AugmentedLagrangianFormulation, + Formulation, + LagrangianModelFormulation, +) from cooper.optim.extra_optimizers import ExtragradientOptimizer from cooper.problem import CMPState +from cooper.formulation.lagrangian_model import CMPModelState from .constrained_optimizer import ConstrainedOptimizer @@ -123,9 +128,9 @@ def custom_sanity_checks(self): def step( self, - closure: Optional[Callable[..., CMPState]] = None, + closure: Optional[Callable[..., Union[CMPState, CMPModelState]]] = None, *closure_args, - defect_fn: Optional[Callable[..., CMPState]] = None, + defect_fn: Optional[Callable[..., Union[CMPState, CMPModelState]]] = None, **closure_kwargs, ): """ @@ -188,9 +193,7 @@ def dual_step(self, call_extrapolation=False): # Flip gradients for multipliers to perform ascent. # We only do the flipping *right before* applying the optimizer step to # avoid accidental double sign flips. - for multiplier in self.formulation.state(): - if multiplier is not None: - multiplier.grad.mul_(-1.0) + self.formulation.flip_dual_gradients() # Update multipliers based on current constraint violations (gradients) if call_extrapolation: @@ -198,7 +201,9 @@ def dual_step(self, call_extrapolation=False): else: self.dual_optimizer.step() - if self.formulation.ineq_multipliers is not None: + if self.formulation.ineq_multipliers is not None and not isinstance( + self.formulation, LagrangianModelFormulation + ): if self.dual_restarts: # "Reset" value of inequality multipliers to zero as soon as # solution becomes feasible diff --git a/cooper/optim/constrained_optimizers/simultaneous_optimizer.py b/cooper/optim/constrained_optimizers/simultaneous_optimizer.py index 2f1c3f53..e65c9876 100644 --- a/cooper/optim/constrained_optimizers/simultaneous_optimizer.py +++ b/cooper/optim/constrained_optimizers/simultaneous_optimizer.py @@ -7,8 +7,9 @@ import torch -from cooper.formulation import Formulation +from cooper.formulation import Formulation, LagrangianModelFormulation from cooper.problem import CMPState +from cooper.formulation.lagrangian_model import CMPModelState from .constrained_optimizer import ConstrainedOptimizer @@ -73,9 +74,9 @@ def __init__( def step( self, - closure: Optional[Callable[..., CMPState]] = None, + closure: Optional[Callable[..., Union[CMPState, CMPModelState]]] = None, *closure_args, - defect_fn: Optional[Callable[..., CMPState]] = None, + defect_fn: Optional[Callable[..., Union[CMPState, CMPModelState]]] = None, **closure_kwargs, ): """ @@ -97,14 +98,14 @@ def dual_step(self): # Flip gradients for multipliers to perform ascent. # We only do the flipping *right before* applying the optimizer step to # avoid accidental double sign flips. 
-        for multiplier in self.formulation.state():
-            if multiplier is not None:
-                multiplier.grad.mul_(-1.0)
+        self.formulation.flip_dual_gradients()

         # Update multipliers based on current constraint violations (gradients)
         self.dual_optimizer.step()

-        if self.formulation.ineq_multipliers is not None:
+        if self.formulation.ineq_multipliers is not None and not isinstance(
+            self.formulation, LagrangianModelFormulation
+        ):
             if self.dual_restarts:
                 # "Reset" value of inequality multipliers to zero as soon as
                 # solution becomes feasible
diff --git a/cooper/utils/state_logger.py b/cooper/utils/state_logger.py
index 7c3347fd..58b3f5b6 100644
--- a/cooper/utils/state_logger.py
+++ b/cooper/utils/state_logger.py
@@ -1,8 +1,9 @@
 from collections import OrderedDict
 from copy import deepcopy
-from typing import List
+from typing import List, Union

 from cooper.problem import CMPState
+from cooper.formulation.lagrangian_model import CMPModelState


 class StateLogger:
@@ -22,7 +23,7 @@ def __init__(self, save_metrics: List[str]):

     def store_metrics(
         self,
-        cmp_state: CMPState,
+        cmp_state: Union[CMPState, CMPModelState],
         step_id: int,
         partial_dict: dict = None,
     ):
diff --git a/tests/helpers/cooper_test_utils.py b/tests/helpers/cooper_test_utils.py
index 193ec61a..a87dbaf2 100644
--- a/tests/helpers/cooper_test_utils.py
+++ b/tests/helpers/cooper_test_utils.py
@@ -10,6 +10,7 @@
 import torch

 import cooper
+from cooper.formulation.lagrangian_model import CMPModelState


 @dataclass
@@ -33,6 +34,7 @@ def build_test_problem(
     dual_optim_cls,
     use_ineq,
     use_proxy_ineq,
+    use_mult_model,
     dual_restarts,
     alternating,
     primal_optim_kwargs={"lr": 1e-2},
@@ -48,8 +50,6 @@
     if skip.do_skip:
         pytest.skip(skip.skip_reason)

-    cmp = Toy2dCMP(use_ineq=use_ineq, use_proxy_ineq=use_proxy_ineq)
-
     if primal_init is None:
         primal_model.to(device)
         params = primal_model.parameters()
@@ -74,12 +74,30 @@
     else:
         primal_optimizers = [primal_optim_cls(params_, **primal_optim_kwargs)]

-    if use_ineq:
+    cmp = Toy2dCMP(
+        device,
+        use_ineq=use_ineq,
+        use_proxy_ineq=use_proxy_ineq,
+        use_mult_model=use_mult_model,
+    )
+
+    if use_mult_model or use_ineq:
         # Constrained case
         dual_optimizer = cooper.optim.partial_optimizer(
             dual_optim_cls, **dual_optim_kwargs
         )
+
+        if use_mult_model:
+            # Exclusive to the model formulation
+            ineq_multiplier_model = ToyMultiplierModel(3, 10, device)
+            formulation = cooper.formulation.LagrangianModelFormulation(
+                cmp, ineq_multiplier_model=ineq_multiplier_model
+            )
+
+        elif use_ineq:
+            # Constrained case with explicit multipliers
             formulation = formulation_cls(cmp)
+
     else:
         # Unconstrained case
         dual_optimizer = None
@@ -103,6 +121,23 @@
     return TestProblemData(params, cmp, coop, formulation, device, mktensor)


+class ToyMultiplierModel(cooper.multipliers.MultiplierModel):
+    """
+    Simplest possible MultiplierModel: a small MLP with a single non-negative output.
+ """ + + def __init__(self, n_params, n_hidden_units, device): + super().__init__() + self.linear1 = torch.nn.Linear(n_params, n_hidden_units, device=device) + self.linear2 = torch.nn.Linear(n_hidden_units, 1, device=device) + + def forward(self, constraint_features: torch.Tensor): + x = self.linear1(constraint_features) + x = torch.relu(x) + x = self.linear2(x) + return torch.reshape(torch.nn.functional.relu(x), (-1,)) + + class Toy2dCMP(cooper.ConstrainedMinimizationProblem): """ Simple test on a 2D quadratically-constrained quadratic programming problem @@ -122,14 +157,28 @@ class Toy2dCMP(cooper.ConstrainedMinimizationProblem): the constant contribution of the constraint level disappears. We include them here for readability. + This problem is designed to be used with the Lagrangian Model, thus, we define + constraint features to feed into the `Multiplier Model`. The first two features + correspont to the exponent of the `x` and `y` variables, respectively. The last + feature correspond to the slack term and the direction of the inequality constraint + (i.e. `-1` for `>=` and `1` for `<=`). + Verified solution from WolframAlpha of the original constrained problem: (x=2/3, y=1/3) Link to WolframAlpha query: https://tinyurl.com/ye8dw6t3 """ - def __init__(self, use_ineq=False, use_proxy_ineq=False): + def __init__( + self, device, use_ineq=False, use_proxy_ineq=False, use_mult_model=False + ): self.use_ineq = use_ineq self.use_proxy_ineq = use_proxy_ineq + self.use_mult_model = use_mult_model + + # Define constraint features + self.constraint_features = torch.tensor( + [[1.0, 1.0, -1.0], [2.0, 1.0, 1.0]], device=device + ) super().__init__() def eval_params(self, params): @@ -159,7 +208,7 @@ def defect_fn(self, params): # No equality constraints eq_defect = None - if self.use_ineq: + if self.use_ineq or self.use_mult_model: # Two inequality constraints ineq_defect = torch.stack( [ @@ -186,6 +235,18 @@ def defect_fn(self, params): ineq_defect = None proxy_ineq_defect = None + # Create inequality constraint features. The first feature is the exponent for + # the x, the second for the y, and the third is the slack term. The sign of the + # slack term depends on the constraint type (i.e. >= or <=). 
+ if self.use_mult_model: + return CMPModelState( + loss=None, + eq_defect=eq_defect, + ineq_defect=ineq_defect, + proxy_ineq_defect=proxy_ineq_defect, + ineq_constraint_features=self.constraint_features, + ) + return cooper.CMPState( loss=None, eq_defect=eq_defect, diff --git a/tests/test_alternating_proxy.py b/tests/test_alternating_proxy.py index 73812e37..0e50e137 100644 --- a/tests/test_alternating_proxy.py +++ b/tests/test_alternating_proxy.py @@ -2,10 +2,11 @@ """Tests for Constrained Optimizer class.""" -import cooper_test_utils import pytest import torch +from .helpers import cooper_test_utils + @pytest.mark.parametrize("aim_device", ["cpu", "cuda"]) def test_manual_alternating_proxy(aim_device): @@ -18,6 +19,7 @@ def test_manual_alternating_proxy(aim_device): dual_optim_cls=torch.optim.SGD, use_ineq=True, use_proxy_ineq=True, + use_mult_model=False, dual_restarts=False, alternating=True, primal_optim_kwargs={"lr": 5e-2, "momentum": 0.0}, diff --git a/tests/test_alternating_update.py b/tests/test_alternating_update.py index 7e2cacb4..53661586 100644 --- a/tests/test_alternating_update.py +++ b/tests/test_alternating_update.py @@ -2,10 +2,11 @@ """Tests for Extrapolation optimizers.""" -import cooper_test_utils import pytest import torch +from .helpers import cooper_test_utils + def problem_data(aim_device, alternating): @@ -16,6 +17,7 @@ def problem_data(aim_device, alternating): dual_optim_cls=torch.optim.SGD, use_ineq=True, use_proxy_ineq=False, + use_mult_model=False, dual_restarts=False, alternating=alternating, ) diff --git a/tests/test_augmented_lagrangian.py b/tests/test_augmented_lagrangian.py index ac77b8e1..84cf59ec 100644 --- a/tests/test_augmented_lagrangian.py +++ b/tests/test_augmented_lagrangian.py @@ -2,12 +2,13 @@ """Tests for Augmented Lagrangian Formulation class.""" -import cooper_test_utils import pytest import torch import cooper +from .helpers import cooper_test_utils + def test_augmented_lagrangian_formulation(): class DummyCMP(cooper.ConstrainedMinimizationProblem): @@ -54,6 +55,7 @@ def test_convergence_augmented_lagrangian(aim_device): dual_optim_cls=torch.optim.SGD, use_ineq=True, use_proxy_ineq=False, + use_mult_model=False, dual_restarts=False, alternating=True, primal_optim_kwargs={"lr": 1e-2}, @@ -108,6 +110,7 @@ def test_manual_augmented_lagrangian(aim_device): dual_optim_cls=torch.optim.SGD, use_ineq=True, use_proxy_ineq=False, + use_mult_model=False, dual_restarts=False, alternating=True, primal_optim_kwargs={"lr": 1e-2}, diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 94d780d9..d685c06c 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -6,14 +6,15 @@ import os import tempfile -# Import basic closure example from helpers -import cooper_test_utils import pytest import torch import cooper from cooper.utils import validate_state_dicts +# Import basic closure example from helpers +from .helpers import cooper_test_utils + def train_for_n_steps(coop, cmp, params, n_step=100): @@ -76,6 +77,7 @@ def test_checkpoint(aim_device, use_ineq, multiple_optimizers): dual_optim_cls=torch.optim.SGD, use_ineq=use_ineq, use_proxy_ineq=False, + use_mult_model=False, dual_restarts=True, alternating=False, primal_optim_kwargs=primal_optim_kwargs, diff --git a/tests/test_extrapolation.py b/tests/test_extrapolation.py index 1131d00d..8125c96c 100644 --- a/tests/test_extrapolation.py +++ b/tests/test_extrapolation.py @@ -2,13 +2,14 @@ """Tests for Extrapolation optimizers.""" -# Import basic closure example from helpers -import 
cooper_test_utils import pytest import torch import cooper +# Import basic closure example from helpers +from .helpers import cooper_test_utils + def problem_data(aim_device, primal_optim_cls): @@ -19,6 +20,7 @@ def problem_data(aim_device, primal_optim_cls): dual_optim_cls=cooper.optim.ExtraSGD, use_ineq=True, use_proxy_ineq=False, + use_mult_model=False, dual_restarts=False, alternating=False, ) @@ -49,7 +51,7 @@ def test_extrapolation(aim_device, primal_optimizer_cls): assert cmp.state.eq_defect is None or cmp.state.eq_defect.is_cuda assert cmp.state.ineq_defect is None or cmp.state.ineq_defect.is_cuda - # TODO: Why do we need such relaxed tolerance for this test to pass? + # TODO(gallego-posada): Why do we need such relaxed tolerance for this test to pass? if primal_optimizer_cls == cooper.optim.ExtraSGD: atol = 1e-8 else: diff --git a/tests/test_lagrangian_formulation.py b/tests/test_lagrangian_formulation.py index 0409952b..ed483700 100644 --- a/tests/test_lagrangian_formulation.py +++ b/tests/test_lagrangian_formulation.py @@ -2,10 +2,21 @@ """Tests for Lagrangian Formulation class.""" +import random + +import numpy as np +import pytest import torch import cooper +from .helpers import cooper_test_utils + +random.seed(121212) +np.random.seed(121212) +torch.manual_seed(121212) +torch.cuda.manual_seed(121211) + def test_lagrangian_formulation(): class DummyCMP(cooper.ConstrainedMinimizationProblem): @@ -28,3 +39,42 @@ def closure(self): ) lf.create_state(cmp.state) assert (lf.ineq_multipliers is not None) and (lf.eq_multipliers is not None) + + +@pytest.mark.parametrize("aim_device", ["cpu", "cuda"]) +def test_convergence_small_toy_problem(aim_device): + + test_problem_data = cooper_test_utils.build_test_problem( + aim_device=aim_device, + primal_optim_cls=torch.optim.SGD, + primal_init=[0.0, -1.0], + dual_optim_cls=torch.optim.SGD, + use_ineq=True, + use_proxy_ineq=False, + use_mult_model=False, + dual_restarts=False, + alternating=True, + primal_optim_kwargs={"lr": 1e-2}, + dual_optim_kwargs={"lr": 1.0}, + dual_scheduler=None, + formulation_cls=cooper.formulation.LagrangianFormulation, + ) + + params, cmp, coop, formulation, device, mktensor = test_problem_data.as_tuple() + + for step_id in range(400): + coop.zero_grad() + lagrangian = formulation.compute_lagrangian( + closure=cmp.closure, + params=params, + ) + formulation.backward(lagrangian) + coop.step(cmp.closure, params) + + if device == "cuda": + assert cmp.state.loss.is_cuda + assert cmp.state.eq_defect is None or cmp.state.eq_defect.is_cuda + assert cmp.state.ineq_defect is None or cmp.state.ineq_defect.is_cuda + + assert torch.allclose(params[0], mktensor(2.0 / 3.0), atol=1e-3) + assert torch.allclose(params[1], mktensor(1.0 / 3.0), atol=1e-3) diff --git a/tests/test_lagrangian_model.py b/tests/test_lagrangian_model.py new file mode 100644 index 00000000..cca1c514 --- /dev/null +++ b/tests/test_lagrangian_model.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python + +"""Tests for Lagrangian Model Formulation class.""" + +import random + +import numpy as np +import pytest +import torch + +import cooper + +from .helpers import cooper_test_utils + +random.seed(121212) +np.random.seed(121212) +torch.manual_seed(121212) +torch.cuda.manual_seed(121211) + + +class DummyCMP(cooper.ConstrainedMinimizationProblem): + def __init__(self): + super().__init__() + + def closure(self): + pass + + +class DummyMultiplierModel(cooper.multipliers.MultiplierModel): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 1) + + def 
forward(self, constraint_features): + return self.linear(constraint_features) + + +@pytest.fixture +def get_dummy_cmp(): + return DummyCMP() + + +@pytest.fixture +def get_ineq_multiplier_model(): + return DummyMultiplierModel() + + +@pytest.fixture +def get_eq_multiplier_model(): + return DummyMultiplierModel() + + +def return_fixture_values(cmp, ineq_multiplier_model, eq_multiplier_model, request): + cmp = request.getfixturevalue(cmp) + if ineq_multiplier_model is not None: + ineq_multiplier_model = request.getfixturevalue(ineq_multiplier_model) + if eq_multiplier_model is not None: + eq_multiplier_model = request.getfixturevalue(eq_multiplier_model) + return cmp, ineq_multiplier_model, eq_multiplier_model + + +@pytest.mark.parametrize("cmp", ["get_dummy_cmp"]) +@pytest.mark.parametrize("ineq_multiplier_model", [None, "get_ineq_multiplier_model"]) +@pytest.mark.parametrize("eq_multiplier_model", [None, "get_eq_multiplier_model"]) +def test_is_state_created(cmp, ineq_multiplier_model, eq_multiplier_model, request): + + cmp, ineq_multiplier_model, eq_multiplier_model = return_fixture_values( + cmp, ineq_multiplier_model, eq_multiplier_model, request + ) + + if ineq_multiplier_model is None and eq_multiplier_model is None: + with pytest.raises(ValueError): + cooper.formulation.LagrangianModelFormulation( + cmp, + ineq_multiplier_model=ineq_multiplier_model, + eq_multiplier_model=eq_multiplier_model, + ) + else: + lmf = cooper.formulation.LagrangianModelFormulation( + cmp, + ineq_multiplier_model=ineq_multiplier_model, + eq_multiplier_model=eq_multiplier_model, + ) + assert lmf.is_state_created + + +@pytest.mark.parametrize( + "cmp, ineq_multiplier_model, eq_multiplier_model", + [ + ("get_dummy_cmp", "get_ineq_multiplier_model", None), + ("get_dummy_cmp", None, "get_eq_multiplier_model"), + ("get_dummy_cmp", "get_ineq_multiplier_model", "get_eq_multiplier_model"), + ], +) +def test_state(cmp, ineq_multiplier_model, eq_multiplier_model, request): + + cmp, ineq_multiplier_model, eq_multiplier_model = return_fixture_values( + cmp, ineq_multiplier_model, eq_multiplier_model, request + ) + + lmf = cooper.formulation.LagrangianModelFormulation( + cmp, + ineq_multiplier_model=ineq_multiplier_model, + eq_multiplier_model=eq_multiplier_model, + ) + + ineq_state, eq_state = lmf.state() + + if ineq_multiplier_model is None: + assert ineq_state is None + else: + assert ineq_state is not None + + if eq_multiplier_model is None: + assert eq_state is None + else: + assert eq_state is not None + + +@pytest.mark.parametrize( + "cmp, ineq_multiplier_model, eq_multiplier_model", + [ + ("get_dummy_cmp", "get_ineq_multiplier_model", None), + ("get_dummy_cmp", None, "get_eq_multiplier_model"), + ("get_dummy_cmp", "get_ineq_multiplier_model", "get_eq_multiplier_model"), + ], +) +def test_flip_dual_gradients(cmp, ineq_multiplier_model, eq_multiplier_model, request): + + cmp, ineq_multiplier_model, eq_multiplier_model = return_fixture_values( + cmp, ineq_multiplier_model, eq_multiplier_model, request + ) + + lmf = cooper.formulation.LagrangianModelFormulation( + cmp, + ineq_multiplier_model=ineq_multiplier_model, + eq_multiplier_model=eq_multiplier_model, + ) + + for constraint_type in ["eq", "ineq"]: + mult_name = constraint_type + "_multiplier_model" + multiplier_model = getattr(lmf, mult_name) + if multiplier_model is not None: + for param in multiplier_model.parameters(): + param.requires_grad = True + param.grad = torch.ones_like(param) + + lmf.flip_dual_gradients() + + for constraint_type in ["eq", "ineq"]: + 
mult_name = constraint_type + "_multiplier_model" + multiplier_model = getattr(lmf, mult_name) + if multiplier_model is not None: + for param in multiplier_model.parameters(): + assert torch.all(param.grad == -1) + + +@pytest.mark.parametrize("aim_device", ["cpu", "cuda"]) +def test_convergence_small_toy_problem(aim_device): + + test_problem_data = cooper_test_utils.build_test_problem( + aim_device=aim_device, + primal_optim_cls=torch.optim.SGD, + primal_init=[0.0, -1.0], + dual_optim_cls=torch.optim.SGD, + use_ineq=True, + use_proxy_ineq=False, + use_mult_model=True, + dual_restarts=False, + alternating=False, + primal_optim_kwargs={"lr": 1e-2}, + dual_optim_kwargs={"lr": 1e-2}, + dual_scheduler=None, + formulation_cls=cooper.formulation.LagrangianModelFormulation, + ) + + params, cmp, coop, formulation, device, mktensor = test_problem_data.as_tuple() + + coop.instantiate_dual_optimizer_and_scheduler() + + for _ in range(400): + coop.zero_grad() + + lagrangian = formulation.compute_lagrangian( + closure=cmp.closure, + params=params, + ) + formulation.backward(lagrangian) + coop.step() + + if device == "cuda": + assert cmp.state.loss.is_cuda + assert cmp.state.eq_defect is None or cmp.state.eq_defect.is_cuda + assert cmp.state.ineq_defect is None or cmp.state.ineq_defect.is_cuda + + assert torch.allclose(params[0], mktensor(2.0 / 3.0), atol=1e-2) + assert torch.allclose(params[1], mktensor(1.0 / 3.0), atol=1e-2) diff --git a/tests/test_lr_schedulers.py b/tests/test_lr_schedulers.py index 71b1818f..5a2515a9 100644 --- a/tests/test_lr_schedulers.py +++ b/tests/test_lr_schedulers.py @@ -2,12 +2,13 @@ """Tests for LR schedulers.""" -import cooper_test_utils import pytest import torch import cooper +from .helpers import cooper_test_utils + @pytest.mark.parametrize("aim_device", ["cpu", "cuda"]) @pytest.mark.parametrize("scheduler_name", ["ExponentialLR", "ConstantLR"]) @@ -40,6 +41,7 @@ def test_lr_schedulers(aim_device, scheduler_name, optimizer_cls): dual_optim_cls=optimizer_cls, use_ineq=True, use_proxy_ineq=True, + use_mult_model=False, dual_restarts=False, alternating=False, primal_optim_kwargs={"lr": base_lr}, diff --git a/tests/test_multiplier_model.py b/tests/test_multiplier_model.py new file mode 100644 index 00000000..302f1ce4 --- /dev/null +++ b/tests/test_multiplier_model.py @@ -0,0 +1,15 @@ +import cooper + + +def test_multiplier_model_init(): + class DummyMultiplierModel(cooper.multipliers.MultiplierModel): + def __init__(self): + super().__init__() + + def forward(self): + pass + + multiplier_model = DummyMultiplierModel() + + assert isinstance(multiplier_model, cooper.multipliers.BaseMultiplier) + assert isinstance(multiplier_model, cooper.multipliers.MultiplierModel) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index d0b2f7b2..502644ef 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -3,10 +3,11 @@ """Tests for Constrained Optimizer class. 
This test already verifies that the code behaves as expected for an unconstrained setting.""" -import cooper_test_utils import pytest import torch +from .helpers import cooper_test_utils + @pytest.mark.parametrize("aim_device", ["cpu", "cuda"]) @pytest.mark.parametrize("use_ineq", [True, False]) @@ -30,6 +31,7 @@ def test_toy_problem(aim_device, use_ineq, multiple_optimizers): dual_optim_cls=torch.optim.SGD, use_ineq=use_ineq, use_proxy_ineq=False, + use_mult_model=False, dual_restarts=True, alternating=False, primal_optim_kwargs=primal_optim_kwargs, diff --git a/tests/test_proxy_constraints.py b/tests/test_proxy_constraints.py index 1ae449ac..db717a36 100644 --- a/tests/test_proxy_constraints.py +++ b/tests/test_proxy_constraints.py @@ -2,10 +2,11 @@ """Tests for Constrained Optimizer class.""" -import cooper_test_utils import pytest import torch +from .helpers import cooper_test_utils + @pytest.mark.parametrize("aim_device", ["cpu", "cuda"]) def test_manual_proxy_constraints(aim_device): @@ -21,6 +22,7 @@ def test_manual_proxy_constraints(aim_device): dual_optim_cls=torch.optim.SGD, use_ineq=True, use_proxy_ineq=True, + use_mult_model=False, dual_restarts=False, alternating=False, primal_optim_kwargs={"lr": 5e-2},
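
Usage sketch (illustrative, not part of the patch): the snippet below wires together the pieces introduced in this PR, namely a `MultiplierModel` subclass, a closure returning a `CMPModelState`, and `LagrangianModelFormulation`, on a hypothetical single-constraint toy problem. The problem, feature values, model architecture, and hyperparameters are made up for illustration; only the `cooper` entry points that appear in the diff above are assumed.

import torch

import cooper
from cooper.formulation.lagrangian_model import CMPModelState, LagrangianModelFormulation


class LinearMultiplierModel(cooper.multipliers.MultiplierModel):
    # Predicts one multiplier per constraint from its feature vector.
    def __init__(self, n_features):
        super().__init__()
        self.linear = torch.nn.Linear(n_features, 1)

    def forward(self, constraint_features):
        # ReLU keeps the predicted inequality multipliers non-negative.
        return torch.relu(self.linear(constraint_features)).flatten()


class ToyCMP(cooper.ConstrainedMinimizationProblem):
    # Minimize x^2 + 2*y^2 subject to x + y >= 1. The single constraint is
    # described by an illustrative 3-dimensional feature vector.
    def __init__(self):
        super().__init__()
        self.constraint_features = torch.tensor([[1.0, 1.0, -1.0]])

    def closure(self, params):
        x, y = params[0], params[1]
        loss = x**2 + 2 * y**2
        ineq_defect = torch.stack([1.0 - x - y])  # <= 0 when feasible
        return CMPModelState(
            loss=loss,
            ineq_defect=ineq_defect,
            ineq_constraint_features=self.constraint_features,
        )


cmp = ToyCMP()
formulation = LagrangianModelFormulation(
    cmp, ineq_multiplier_model=LinearMultiplierModel(n_features=3)
)

params = torch.nn.Parameter(torch.tensor([0.0, 0.0]))
primal_optimizer = torch.optim.SGD([params], lr=1e-2)
dual_optimizer = torch.optim.SGD(formulation.dual_parameters, lr=1e-2)

# Simultaneous primal/dual updates; the analytical solution of the toy
# problem is (x, y) = (2/3, 1/3).
for _ in range(1000):
    primal_optimizer.zero_grad()
    dual_optimizer.zero_grad()
    lagrangian = formulation.compute_lagrangian(closure=cmp.closure, params=params)
    formulation.backward(lagrangian)  # fills primal and dual gradients
    primal_optimizer.step()
    # Flip dual gradients right before the dual step to perform ascent.
    formulation.flip_dual_gradients()
    dual_optimizer.step()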