Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move some general utils here #51

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,21 @@ The package currently contains the following public functions and classes:
- `get_logger()`: convenience function to get (or create) a logger with given `name` as a child of the universal `astar` logger.
- `get_astar_logger()`: convenience function to get (or create) a logger with the name `astar`, which serves as the root for all A*V packages and applications.
- `SpectralType`: a class to parse, store and compare spectral type designations.
- `close_loop()`: an iterator function to add the first element back to the end.
- `stringify_dict()`: convert all non-primitive dict values to strings.
- `check_keys()`: check if required keys are present in dict and warn or raise otherwise.

### Loggers module

- `loggers.ColoredFormatter`: a subclass of `logging.Formatter` to produce colored logging messages for console output.

### Exceptions module

- `AstarWarning`: base class for warnings within the Astar ecosystem.
- `AstarUserWarning`: subclass of `AstarWarning` and the builtin `UserWarning`.

To be expanded with custom Error base classes for astar.

## Dependencies

Dependencies are intentionally kept to a minimum for simplicity. Current dependencies are:
Expand Down
2 changes: 2 additions & 0 deletions astar_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# -*- coding: utf-8 -*-

from .exceptions import AstarWarning, AstarUserWarning
from .nested_mapping import (NestedMapping, RecursiveNestedMapping,
NestedChainMap, is_bangkey, is_nested_mapping)
from .unique_list import UniqueList
from .badges import Badge, BadgeReport
from .loggers import get_logger, get_astar_logger
from .spectral_types import SpectralType
from .general_utils import close_loop, stringify_dict, check_keys
18 changes: 18 additions & 0 deletions astar_utils/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
"""Exceptions (Errors and Warning) subclasses for use in astar projects.

Inspirations for AstarWarning and AstarUserWarning are taken from the same
concept in Astropy.
"""


class AstarWarning(Warning):
"""The base warning class from which all Astar warnings should inherit."""


class AstarUserWarning(UserWarning, AstarWarning):
"""
The primary warning class for Astar.

Use this if you do not need a specific sub-class.
"""
107 changes: 107 additions & 0 deletions astar_utils/general_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
"""General utility function, mostly exported from ScopeSim."""

from warnings import warn
from collections.abc import Iterable, Generator, Set, Mapping

from .exceptions import AstarWarning, AstarUserWarning


def close_loop(iterable: Iterable) -> Generator:
"""
Add the first element of an iterable to the end again.

This is useful for e.g. plotting a closed shape from a list of points.

Parameters
----------
iterable : Iterable
Input iterable with n elements.

Yields
------
loop : Generator
Output iterable with n+1 elements.

Examples
--------
>>> x, y = [1, 2, 3], [4, 5, 6]
>>> x, y = zip(*close_loop(zip(x, y)))
>>> x
(1, 2, 3, 1)
>>> y
(4, 5, 6, 4)

"""
iterator = iter(iterable)
first = next(iterator)
yield first
yield from iterator
yield first


def stringify_dict(dic: Mapping, allowed_types=(str, int, float, bool)):
"""Turn a dict entries into strings for addition to FITS headers."""
for key, value in dic.items():
if isinstance(value, allowed_types):
yield key, value
else:
yield key, str(value)


def check_keys(
input_dict: Iterable,
required_keys: Set,
action: str = "error",
all_any: str = "all",
) -> bool:
"""
Check to see if all/any of the required keys are present in a dict.

Parameters
----------
input_dict : Union[Mapping, Iterable]
The mapping to be checked.
required_keys : Set
Set containing the keys to look for.
action : {"error", "warn", "warning"}, optional
What to do in case the check does not pass. The default is "error".
all_any : {"all", "any"}, optional
Whether to check if "all" or "any" of the `required_keys` are present.
The default is "all".

Raises
------
ValueError
Raised when an invalid parameter was passed or when `action` was set to
"error" (the default) and the `required_keys` were not found.

Returns
-------
keys_present : bool
``True`` if check succeded, ``False`` otherwise.

"""
# Checking for Set from collections.abc instead of builtin set to allow
# for any duck typing (e.g. dict keys view or whatever)
if not isinstance(required_keys, Set):
warn("required_keys should implement the Set protocol, found "
f"{type(required_keys)} instead.", AstarWarning)
required_keys = set(required_keys)

if all_any == "all":
keys_present = required_keys.issubset(input_dict)
elif all_any == "any":
keys_present = not required_keys.isdisjoint(input_dict)
else:
raise ValueError("all_any must be either 'all' or 'any'")

if not keys_present:
missing = "', '".join(required_keys.difference(input_dict)) or "<none>"
msg = f"The keys '{missing}' are missing from input_dict."
if "error" in action:
raise ValueError(msg)
if "warn" in action:
warn(msg, AstarUserWarning)

return keys_present
52 changes: 52 additions & 0 deletions tests/test_general_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
"""Unit tests for general_utils.py."""

import pytest

from astar_utils import close_loop, stringify_dict, check_keys, AstarWarning


def test_close_loop():
initer = [1, 2, 3]
outiter = close_loop(initer)
assert list(outiter) == initer + [1]


def test_stringify_dict():
indic = {"a": "foo", "b": 42, "c": [1, 2, 3], "d": True}
outdic = dict(stringify_dict(indic))
assert outdic["a"] == indic["a"]
assert outdic["b"] == indic["b"]
assert outdic["c"] == str(indic["c"])
assert outdic["d"] == indic["d"]


class TestCheckKeys:
@pytest.mark.filterwarnings("ignore::UserWarning")
@pytest.mark.parametrize(("req", "res"), [({"foo", "baz"}, True),
({"bogus", "baz"}, False)])
def test_warn_all(self, req, res):
tstdic = {"foo": 5, "bar": 2, "baz": 7}
assert check_keys(tstdic, req, action="warning") is res

@pytest.mark.filterwarnings("ignore::UserWarning")
@pytest.mark.parametrize(("req", "res"), [({"bogus", "baz"}, True),
({"bogus", "meh"}, False)])
def test_warn_any(self, req, res):
tstdic = {"foo": 5, "bar": 2, "baz": 7}
assert check_keys(tstdic, req, action="warning", all_any="any") is res

def test_raises_by_default(self):
tstdic = {"foo": 5, "bar": 2, "baz": 7}
with pytest.raises(ValueError):
check_keys(tstdic, {"bogus"})

def test_raises_for_invalid_any_all(self):
tstdic = {"foo": 5, "bar": 2, "baz": 7}
with pytest.raises(ValueError):
check_keys(tstdic, {"foo"}, all_any="bogus")

def test_warns_for_bad_type(self):
tstdic = {"a": 5, "b": 2, "c": 7}
with pytest.warns(AstarWarning):
check_keys(tstdic, "abc")