Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved error messages from wraps_ufunc #108

Merged
merged 10 commits into from
Sep 19, 2024
8 changes: 8 additions & 0 deletions cfspopcon/algorithm_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ def from_single_function(
skip_registration: bool = False,
) -> Algorithm:
"""Build an Algorithm which wraps a single function."""
if not isinstance(return_keys, list):
return_keys = [return_keys]

@wraps(func)
def wrapped_function(**kwargs: Any) -> dict:
Expand Down Expand Up @@ -310,6 +312,12 @@ def wrapper(**kwargs: Any) -> xr.Dataset:
self._name = name
self.__doc__ = self._make_docstring()

@classmethod
def from_list(cls, keys: list[str]) -> CompositeAlgorithm:
"""Build a CompositeAlgorithm from a list of Algorithm names."""
algorithms = [Algorithm.get_algorithm(key) for key in keys]
return CompositeAlgorithm(algorithms=algorithms)

def _make_docstring(self) -> str:
"""Makes a doc-string detailing the function inputs and outputs."""
components = f"[{', '.join(alg._name for alg in self.algorithms)}]"
Expand Down
17 changes: 11 additions & 6 deletions cfspopcon/file_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import json
import sys
import warnings
from pathlib import Path
from typing import Any, Literal
from typing import Any, Literal, Union

if sys.version_info >= (3, 11, 0):
from typing import Self # type:ignore[attr-defined,unused-ignore]
Expand All @@ -23,18 +24,22 @@
]


def sanitize_variable(val: xr.DataArray, key: str) -> xr.DataArray:
def sanitize_variable(val: xr.DataArray, key: str) -> Union[xr.DataArray, str]:
"""Strip units and Enum values from a variable so that it can be stored in a NetCDF file."""
try:
val = convert_to_default_units(val, key).pint.dequantify()
except KeyError:
pass

if val.dtype == object:
if val.size == 1:
val = val.item().name
else:
val = xr.DataArray([v.name for v in val.values])
try:
if val.size == 1:
val = val.item().name
else:
val = xr.DataArray([v.name for v in val.values])
except AttributeError:
warnings.warn(f"Cannot handle {key}. Dropping variable.", stacklevel=3)
return "UNHANDLED"

return val

Expand Down
37 changes: 30 additions & 7 deletions cfspopcon/unit_handling/decorator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Defines the wraps_ufunc decorator used to perform unit conversions and dimension handling."""

import functools
import itertools
import warnings
from collections.abc import Callable, Mapping, Sequence, Set
from inspect import Parameter, Signature, signature
Expand Down Expand Up @@ -79,14 +80,21 @@ def wraps_ufunc( # noqa: PLR0915
else:
input_core_dims = len(pass_as_positional_args) * [()]

def _wraps_ufunc(func: FunctionType) -> FunctionType:
def _wraps_ufunc(func: FunctionType) -> FunctionType: # noqa: PLR0915
func_signature = signature(func)
func_parameters = func_signature.parameters

if not list(input_units.keys()) == list(func_parameters.keys()):
raise ValueError(
f"Keys for input_units {input_units.keys()} did not match func_parameters {func_parameters.keys()} (n.b. order matters!)"
)
message = f"Keys for input_units for {func.__name__} did not match the declared function inputs (n.b. order matters!)"
message += "\ni: input_key, func_param"
for i, (input_key, func_param) in enumerate(
itertools.zip_longest(list(input_units.keys()), list(func_parameters.keys()), fillvalue="MISSING")
):
message += f"\n{i}: {input_key}, {func_param}"
if not input_key == func_param:
message += " DOES NOT MATCH"

raise ValueError(message)

default_values = {key: val.default for key, val in func_parameters.items() if val.default is not Parameter.empty}

Expand Down Expand Up @@ -187,8 +195,14 @@ def _check_units(units_dict: dict[str, Union[str, Unit, None]]) -> dict[str, Uni


def _return_magnitude_in_specified_units(vals: Any, units_mapping: dict[str, Union[str, Unit, None]]) -> dict[str, Any]:
if not set(vals.keys()) == set(units_mapping):
raise ValueError(f"Incorrect input arguments: argument keys {vals.keys()} did not match units_mapping keys {units_mapping.keys()}")
vals_set, units_set = set(vals.keys()), set(units_mapping)
if not vals_set == units_set:
message = "Incorrect input arguments."
if vals_set - units_set:
message += f"\nUnused arguments given for: {vals_set - units_set}"
if units_set - vals_set:
message += f"\nMissing arguments for: {units_set - vals_set}"
raise ValueError(message)

converted_vals = {}

Expand All @@ -215,11 +229,20 @@ def _return_magnitude_in_specified_units(vals: Any, units_mapping: dict[str, Uni


def _convert_return_to_quantities(vals: Any, units_mapping: dict[str, Union[str, Unit, None]]) -> dict[str, Any]:
if isinstance(vals, xr.DataArray) and vals.ndim == 0:
# Calling wraps_ufunc with scalar values and multiple returns results in
# a xr.DataArray with a single tuple element.
tbody-cfs marked this conversation as resolved.
Show resolved Hide resolved
vals = vals.item()

if not isinstance(vals, tuple):
vals = (vals,)

if not len(vals) == len(units_mapping):
raise ValueError(f"Number of returned values ({len(vals)}) did not match length of units_mapping ({len(units_mapping)})")
message = f"Number of returned values ({len(vals)}) did not match length of units_mapping ({len(units_mapping)})"
message += "\ni: return_key, returned_value"
for i, (return_key, return_param) in enumerate(itertools.zip_longest(list(units_mapping.keys()), vals, fillvalue="MISSING")):
message += f"\n{i}: {return_key}, {return_param}"
raise ValueError(message)
vals = dict(zip(units_mapping.keys(), vals))

converted_vals = {}
Expand Down