Skip to content

Commit

Permalink
add utility for printing shapes as strings (pytorch#1947)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1947

This is useful for printing exception messages in a nice format that is easily verified with `assertRaisesRegex`.

Reviewed By: Balandat

Differential Revision: D47710862

fbshipit-source-id: 8f1a049fcd775889e086918ab8c9edf944fab8c6
  • Loading branch information
sdaulton authored and facebook-github-bot committed Jul 26, 2023
1 parent 8c763b3 commit 4871a9a
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
5 changes: 5 additions & 0 deletions botorch/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import logging

import torch

LOG_LEVEL_DEFAULT = logging.CRITICAL

Expand Down Expand Up @@ -36,4 +37,8 @@ def _get_logger(
return logger


def shape_to_str(shape: torch.Size) -> str:
return f"`{' x '.join(str(i) for i in shape)}`"


logger = _get_logger()
5 changes: 3 additions & 2 deletions botorch/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import torch
from botorch import settings
from botorch.exceptions.errors import BotorchTensorDimensionError, InputDataError
from botorch.logging import shape_to_str
from botorch.models.utils.assorted import fantasize as fantasize_flag
from botorch.posteriors import Posterior, PosteriorList
from botorch.sampling.base import MCSampler
Expand Down Expand Up @@ -581,8 +582,8 @@ def fantasize(
):
raise BotorchTensorDimensionError(
f"Expected evaluation_mask of shape `{X.shape[0]} "
f"x {self.num_outputs}`, but got `"
f"{' x '.join(str(i) for i in evaluation_mask.shape)}`."
f"x {self.num_outputs}`, but got "
f"{shape_to_str(evaluation_mask.shape)}."
)
if not isinstance(sampler, ListSampler):
raise ValueError("Decoupled fantasization requires a list of samplers.")
Expand Down
10 changes: 9 additions & 1 deletion test/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

import logging

import torch

from botorch import settings
from botorch.logging import LOG_LEVEL_DEFAULT, logger
from botorch.logging import LOG_LEVEL_DEFAULT, logger, shape_to_str
from botorch.utils.testing import BotorchTestCase


Expand All @@ -31,3 +33,9 @@ def test_settings_log_level(self):
self.assertEqual(logger.level, logging.INFO)
# Finally, verify the original level is set again
self.assertEqual(logger.level, LOG_LEVEL_DEFAULT)

def test_shape_to_str(self):
self.assertEqual("``", shape_to_str(torch.Size([])))
self.assertEqual("`1`", shape_to_str(torch.Size([1])))
self.assertEqual("`1 x 2`", shape_to_str(torch.Size([1, 2])))
self.assertEqual("`1 x 2 x 3`", shape_to_str(torch.Size([1, 2, 3])))

0 comments on commit 4871a9a

Please sign in to comment.