Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise ValueError on ambiguous boolean conversion and add pending NumPy functions/ufuncs from issue tracker #965

Merged
merged 2 commits into from
Dec 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ Pint Changelog
0.10 (unreleased)
-----------------

- **BREAKING CHANGE**:
Boolean value of Quantities with offsets units is ambiguous, and so, now a ValueError
is raised when attempting to cast such a Quantity to boolean.
(Issue #965, Thanks Jon Thielen)
- Documentation on Pint's array type compatibility has been added to the NumPy support
page, including a graph of the duck array type casting hierarchy as understood by Pint
for N-dimensional arrays.
Expand Down
4 changes: 2 additions & 2 deletions docs/numpy.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@
"\n",
"The following [ufuncs](http://docs.scipy.org/doc/numpy/reference/ufuncs.html) can be applied to a Quantity object:\n",
"\n",
"- **Math operations**: `add`, `subtract`, `multiply`, `divide`, `logaddexp`, `logaddexp2`, `true_divide`, `floor_divide`, `negative`, `remainder`, `mod`, `fmod`, `absolute`, `rint`, `sign`, `conj`, `exp`, `exp2`, `log`, `log2`, `log10`, `expm1`, `log1p`, `sqrt`, `square`, `reciprocal`\n",
"- **Math operations**: `add`, `subtract`, `multiply`, `divide`, `logaddexp`, `logaddexp2`, `true_divide`, `floor_divide`, `negative`, `remainder`, `mod`, `fmod`, `absolute`, `rint`, `sign`, `conj`, `exp`, `exp2`, `log`, `log2`, `log10`, `expm1`, `log1p`, `sqrt`, `square`, `cbrt`, `reciprocal`\n",
"- **Trigonometric functions**: `sin`, `cos`, `tan`, `arcsin`, `arccos`, `arctan`, `arctan2`, `hypot`, `sinh`, `cosh`, `tanh`, `arcsinh`, `arccosh`, `arctanh`\n",
"- **Comparison functions**: `greater`, `greater_equal`, `less`, `less_equal`, `not_equal`, `equal`\n",
"- **Floating functions**: `isreal`, `iscomplex`, `isfinite`, `isinf`, `isnan`, `signbit`, `copysign`, `nextafter`, `modf`, `ldexp`, `frexp`, `fmod`, `floor`, `ceil`, `trunc`\n",
Expand All @@ -301,7 +301,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"['alen', 'amax', 'amin', 'append', 'argmax', 'argmin', 'argsort', 'around', 'atleast_1d', 'atleast_2d', 'atleast_3d', 'average', 'block', 'broadcast_to', 'clip', 'column_stack', 'compress', 'concatenate', 'copy', 'copyto', 'count_nonzero', 'cross', 'cumprod', 'cumproduct', 'cumsum', 'diagonal', 'diff', 'dot', 'dstack', 'ediff1d', 'einsum', 'empty_like', 'expand_dims', 'fix', 'flip', 'full_like', 'gradient', 'hstack', 'insert', 'interp', 'isclose', 'iscomplex', 'isin', 'isreal', 'linspace', 'mean', 'median', 'meshgrid', 'moveaxis', 'nan_to_num', 'nanargmax', 'nanargmin', 'nancumprod', 'nancumsum', 'nanmax', 'nanmean', 'nanmedian', 'nanmin', 'nanpercentile', 'nanstd', 'nansum', 'nanvar', 'ndim', 'nonzero', 'ones_like', 'pad', 'percentile', 'ptp', 'ravel', 'resize', 'result_type', 'rollaxis', 'rot90', 'round_', 'searchsorted', 'shape', 'size', 'sort', 'squeeze', 'stack', 'std', 'sum', 'swapaxes', 'tile', 'transpose', 'trapz', 'trim_zeros', 'unwrap', 'var', 'vstack', 'where', 'zeros_like']\n"
"['alen', 'all', 'amax', 'amin', 'any', 'append', 'argmax', 'argmin', 'argsort', 'around', 'atleast_1d', 'atleast_2d', 'atleast_3d', 'average', 'block', 'broadcast_to', 'clip', 'column_stack', 'compress', 'concatenate', 'copy', 'copyto', 'count_nonzero', 'cross', 'cumprod', 'cumproduct', 'cumsum', 'diagonal', 'diff', 'dot', 'dstack', 'ediff1d', 'einsum', 'empty_like', 'expand_dims', 'fix', 'flip', 'full_like', 'gradient', 'hstack', 'insert', 'interp', 'isclose', 'iscomplex', 'isin', 'isreal', 'linalg.solve', 'linspace', 'mean', 'median', 'meshgrid', 'moveaxis', 'nan_to_num', 'nanargmax', 'nanargmin', 'nancumprod', 'nancumsum', 'nanmax', 'nanmean', 'nanmedian', 'nanmin', 'nanpercentile', 'nanstd', 'nansum', 'nanvar', 'ndim', 'nonzero', 'ones_like', 'pad', 'percentile', 'ptp', 'ravel', 'resize', 'result_type', 'rollaxis', 'rot90', 'round_', 'searchsorted', 'shape', 'size', 'sort', 'squeeze', 'stack', 'std', 'sum', 'swapaxes', 'tile', 'transpose', 'trapz', 'trim_zeros', 'unwrap', 'var', 'vstack', 'where', 'zeros_like']\n"
]
}
],
Expand Down
38 changes: 35 additions & 3 deletions pint/numpy_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ def get_op_output_unit(unit_op, first_input_units, all_args=None, size=None):
result_unit = first_input_units ** 2
elif unit_op == "sqrt":
result_unit = first_input_units ** 0.5
elif unit_op == "cbrt":
result_unit = first_input_units ** (1 / 3)
elif unit_op == "reciprocal":
result_unit = first_input_units ** -1
elif unit_op == "size":
Expand Down Expand Up @@ -255,7 +257,11 @@ def implement_func(func_type, func_str, input_units=None, output_unit=None):
if np is None:
return

func = getattr(np, func_str)
# Handle functions in submodules
func_str_split = func_str.split(".")
func = getattr(np, func_str_split[0])
for func_str_piece in func_str_split[1:]:
func = getattr(func, func_str_piece)

@implements(func_str, func_type)
def implementation(*args, **kwargs):
Expand Down Expand Up @@ -295,6 +301,7 @@ def implementation(*args, **kwargs):
"variance",
"square",
"sqrt",
"cbrt",
"reciprocal",
"size",
]:
Expand Down Expand Up @@ -408,6 +415,7 @@ def implementation(*args, **kwargs):
"divide": "div",
"floor_divide": "div",
"sqrt": "sqrt",
"cbrt": "cbrt",
"square": "square",
"reciprocal": "reciprocal",
"std": "sum",
Expand Down Expand Up @@ -665,6 +673,24 @@ def _recursive_convert(arg, unit):
)


@implements("any", "function")
def _any(a, *args, **kwargs):
# Only valid when multiplicative unit/no offset
if a._is_multiplicative:
return np.any(a._magnitude, *args, **kwargs)
else:
raise ValueError("Boolean value of Quantity with offset unit is ambiguous.")


@implements("all", "function")
def _all(a, *args, **kwargs):
# Only valid when multiplicative unit/no offset
if a._is_multiplicative:
return np.all(a._magnitude, *args, **kwargs)
else:
raise ValueError("Boolean value of Quantity with offset unit is ambiguous.")


# Implement simple matching-unit or stripped-unit functions based on signature


Expand Down Expand Up @@ -836,6 +862,8 @@ def implementation(a, *args, **kwargs):
implement_func("function", func_str, input_units=None, output_unit="delta")
for func_str in ["gradient"]:
implement_func("function", func_str, input_units=None, output_unit="delta,div")
for func_str in ["linalg.solve"]:
implement_func("function", func_str, input_units=None, output_unit="div")
for func_str in ["var", "nanvar"]:
implement_func("function", func_str, input_units=None, output_unit="variance")

Expand All @@ -846,11 +874,15 @@ def numpy_wrap(func_type, func, args, kwargs, types):

if func_type == "function":
handled = HANDLED_FUNCTIONS
# Need to handle functions in submodules
name = ".".join(func.__module__.split(".")[1:] + [func.__name__])
elif func_type == "ufunc":
handled = HANDLED_UFUNCS
# ufuncs do not have func.__module__
name = func.__name__
else:
raise ValueError("Invalid func_type {}".format(func_type))

if func.__name__ not in handled or any(is_upcast_type(t) for t in types):
if name not in handled or any(is_upcast_type(t) for t in types):
return NotImplemented
return handled[func.__name__](*args, **kwargs)
return handled[name](*args, **kwargs)
6 changes: 5 additions & 1 deletion pint/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1481,7 +1481,11 @@ def compare(self, other, op):
__gt__ = lambda self, other: self.compare(other, op=operator.gt)

def __bool__(self):
return bool(self._magnitude)
# Only cast when non-ambiguous (when multiplicative unit)
if self._is_multiplicative:
return bool(self._magnitude)
else:
raise ValueError("Boolean value of Quantity with offset unit is ambiguous.")

__nonzero__ = __bool__

Expand Down
27 changes: 27 additions & 0 deletions pint/testsuite/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,13 @@ def test_einsum(self):
np.array([30, 80, 130, 180, 230]) * self.ureg.m ** 2,
)

@helpers.requires_array_function_protocol()
def test_solve(self):
self.assertQuantityAlmostEqual(
np.linalg.solve(self.q, [[3], [7]] * self.ureg.s),
self.Q_([[1], [1]], "m / s"),
)

# Arithmetic operations
def test_addition_with_scalar(self):
a = np.array([0, 1, 2])
Expand Down Expand Up @@ -414,6 +421,14 @@ def test_power(self):
)
self.assertNDArrayEqual(arr ** self.Q_(2), np.array([0, 1, 4]))

def test_sqrt(self):
q = self.Q_(100, "m**2")
self.assertQuantityEqual(np.sqrt(q), self.Q_(10, "m"))

def test_cbrt(self):
q = self.Q_(1000, "m**3")
self.assertQuantityEqual(np.cbrt(q), self.Q_(10, "m"))

@unittest.expectedFailure
@helpers.requires_numpy()
def test_exponentiation_array_exp_2(self):
Expand Down Expand Up @@ -537,6 +552,18 @@ def test_nonzero_numpy_func(self):
q = [1, 0, 5, 6, 0, 9] * self.ureg.m
self.assertNDArrayEqual(np.nonzero(q)[0], [0, 2, 3, 5])

@helpers.requires_array_function_protocol()
def test_any_numpy_func(self):
q = [0, 1] * self.ureg.m
self.assertTrue(np.any(q))
self.assertRaises(ValueError, np.any, self.q_temperature)

@helpers.requires_array_function_protocol()
def test_all_numpy_func(self):
q = [0, 1] * self.ureg.m
self.assertFalse(np.all(q))
self.assertRaises(ValueError, np.all, self.q_temperature)

@helpers.requires_array_function_protocol()
def test_count_nonzero_numpy_func(self):
q = [1, 0, 5, 6, 0, 9] * self.ureg.m
Expand Down
2 changes: 2 additions & 0 deletions pint/testsuite/test_quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def test_quantity_bool(self):
self.assertTrue(self.Q_(1, "meter"))
self.assertFalse(self.Q_(0, None))
self.assertFalse(self.Q_(0, "meter"))
self.assertRaises(ValueError, bool, self.Q_(0, "degC"))
self.assertFalse(self.Q_(0, "delta_degC"))

def test_quantity_comparison(self):
x = self.Q_(4.2, "meter")
Expand Down