diff --git a/botorch/generation/gen.py b/botorch/generation/gen.py index 0d70681fd5..35b8e87730 100644 --- a/botorch/generation/gen.py +++ b/botorch/generation/gen.py @@ -265,17 +265,21 @@ def f(x): # so it shouldn't be an issue given enough restarts. if nonlinear_inequality_constraints: for con, is_intrapoint in nonlinear_inequality_constraints: - if not nonlinear_constraint_is_feasible( - con, is_intrapoint=is_intrapoint, x=candidates - ): - candidates = torch.from_numpy(x0).to(candidates).reshape(shapeX) + if not ( + feasible := nonlinear_constraint_is_feasible( + con, is_intrapoint=is_intrapoint, x=candidates + ) + ).all(): + # Replace the infeasible batches with feasible ICs. + candidates[~feasible] = ( + torch.from_numpy(x0).to(candidates).reshape(shapeX)[~feasible] + ) warnings.warn( "SLSQP failed to converge to a solution the satisfies the " "non-linear constraints. Returning the feasible starting point.", OptimizationWarning, stacklevel=2, ) - break clamped_candidates = columnwise_clamp( X=candidates, lower=lower_bounds, upper=upper_bounds, raise_on_violation=True diff --git a/botorch/optim/parameter_constraints.py b/botorch/optim/parameter_constraints.py index e5719d7f63..069ad2f5e7 100644 --- a/botorch/optim/parameter_constraints.py +++ b/botorch/optim/parameter_constraints.py @@ -512,7 +512,7 @@ def f_grad(X): def nonlinear_constraint_is_feasible( nonlinear_inequality_constraint: Callable, is_intrapoint: bool, x: Tensor -) -> bool: +) -> Tensor: """Checks if a nonlinear inequality constraint is fulfilled. Args: @@ -522,23 +522,24 @@ def nonlinear_constraint_is_feasible( is applied pointwise and is broadcasted over the q-batch. Else, the constraint has to evaluated over the whole q-batch and is a an inter-point constraint. - x: Tensor of shape (b x q x d). + x: Tensor of shape (batch x q x d). Returns: - bool: True if the constraint is fulfilled, else False. 
+ A boolean tensor of shape (batch) indicating if the constraint is + satisfied by the corresponding batch of `x`. """ def check_x(x: Tensor) -> bool: return _arrayify(nonlinear_inequality_constraint(x)).item() >= NLC_TOL - for x_ in x: + x_flat = x.view(-1, *x.shape[-2:]) + is_feasible = torch.ones(x_flat.shape[0], dtype=torch.bool, device=x.device) + for i, x_ in enumerate(x_flat): if is_intrapoint: - if not all(check_x(x__) for x__ in x_): - return False + is_feasible[i] &= all(check_x(x__) for x__ in x_) else: - if not check_x(x_): - return False - return True + is_feasible[i] &= check_x(x_) + return is_feasible.view(x.shape[:-2]) def make_scipy_nonlinear_inequality_constraints( @@ -589,7 +590,7 @@ def make_scipy_nonlinear_inequality_constraints( nlc, is_intrapoint = constraint if not nonlinear_constraint_is_feasible( nlc, is_intrapoint=is_intrapoint, x=x0.reshape(shapeX) - ): + ).all(): raise ValueError( "`batch_initial_conditions` must satisfy the non-linear inequality " "constraints." 
diff --git a/test/optim/test_parameter_constraints.py b/test/optim/test_parameter_constraints.py index 31752a085b..df0cad98b7 100644 --- a/test/optim/test_parameter_constraints.py +++ b/test/optim/test_parameter_constraints.py @@ -358,7 +358,7 @@ def nlc(x): ), ) ) - self.assertFalse( + self.assertEqual( nonlinear_constraint_is_feasible( nlc, True, @@ -366,7 +366,8 @@ def nlc(x): [[[1.5, 1.5], [1.5, 1.5]], [[1.5, 1.5], [1.5, 3.5]]], device=self.device, ), - ) + ).tolist(), + [True, False], ) self.assertTrue( nonlinear_constraint_is_feasible( @@ -381,14 +382,14 @@ def nlc(x): [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], device=self.device, ), - ) + ).all() ) self.assertFalse( nonlinear_constraint_is_feasible( nlc, False, torch.tensor([[[1.5, 1.5], [1.5, 1.5]]], device=self.device) ) ) - self.assertFalse( + self.assertEqual( nonlinear_constraint_is_feasible( nlc, False, @@ -396,7 +397,8 @@ def nlc(x): [[[1.0, 1.0], [1.0, 1.0]], [[1.5, 1.5], [1.5, 1.5]]], device=self.device, ), - ) + ).tolist(), + [True, False], ) def test_generate_unfixed_nonlin_constraints(self):