Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin' into fix_windows_timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
esantorella committed Jul 26, 2023
2 parents 5ce0fc3 + 4871a9a commit 918635a
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
5 changes: 5 additions & 0 deletions botorch/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import logging

import torch

LOG_LEVEL_DEFAULT = logging.CRITICAL

Expand Down Expand Up @@ -36,4 +37,8 @@ def _get_logger(
return logger


def shape_to_str(shape: torch.Size) -> str:
return f"`{' x '.join(str(i) for i in shape)}`"


logger = _get_logger()
5 changes: 3 additions & 2 deletions botorch/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import torch
from botorch import settings
from botorch.exceptions.errors import BotorchTensorDimensionError, InputDataError
from botorch.logging import shape_to_str
from botorch.models.utils.assorted import fantasize as fantasize_flag
from botorch.posteriors import Posterior, PosteriorList
from botorch.sampling.base import MCSampler
Expand Down Expand Up @@ -581,8 +582,8 @@ def fantasize(
):
raise BotorchTensorDimensionError(
f"Expected evaluation_mask of shape `{X.shape[0]} "
f"x {self.num_outputs}`, but got `"
f"{' x '.join(str(i) for i in evaluation_mask.shape)}`."
f"x {self.num_outputs}`, but got "
f"{shape_to_str(evaluation_mask.shape)}."
)
if not isinstance(sampler, ListSampler):
raise ValueError("Decoupled fantasization requires a list of samplers.")
Expand Down
10 changes: 9 additions & 1 deletion test/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

import logging

import torch

from botorch import settings
from botorch.logging import LOG_LEVEL_DEFAULT, logger
from botorch.logging import LOG_LEVEL_DEFAULT, logger, shape_to_str
from botorch.utils.testing import BotorchTestCase


Expand All @@ -31,3 +33,9 @@ def test_settings_log_level(self):
self.assertEqual(logger.level, logging.INFO)
# Finally, verify the original level is set again
self.assertEqual(logger.level, LOG_LEVEL_DEFAULT)

def test_shape_to_str(self):
self.assertEqual("``", shape_to_str(torch.Size([])))
self.assertEqual("`1`", shape_to_str(torch.Size([1])))
self.assertEqual("`1 x 2`", shape_to_str(torch.Size([1, 2])))
self.assertEqual("`1 x 2 x 3`", shape_to_str(torch.Size([1, 2, 3])))

0 comments on commit 918635a

Please sign in to comment.