Skip to content

Commit

Permalink
Add matmul ufunc with support for @ operator
Browse files Browse the repository at this point in the history
Adds support for __matmul__ and __rmatmul__ (@) operator, by way of
numpy matmul ufunc (only supported in Python >= 3.5). Includes
workaround for np.matmul returning units in numpy > 1.15 and not in
numpy <= 1.15.

Closes #859
  • Loading branch information
jthielen committed Aug 31, 2019
1 parent 55058bc commit 11ddbfa
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
19 changes: 17 additions & 2 deletions pint/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,21 @@ def __mul__(self, other):

__rmul__ = __mul__

def __matmul__(self, other):
try:
ret = np.matmul(self, other)
if not self._check(ret):
if self._check(other):
units = self._units * other._units
else:
units = self._units
ret = self.__class__(ret, units)
return ret
except (AttributeError, TypeError):
return NotImplemented

__rmatmul__ = __matmul__

def __itruediv__(self, other):
if not isinstance(self._magnitude, ndarray):
return self._mul_div(other, operator.truediv)
Expand Down Expand Up @@ -1317,10 +1332,10 @@ def __bool__(self):
#: will be raised.
__prod_units = {'var': 2, 'prod': 'size', 'multiply': 'mul',
'true_divide': 'div', 'divide': 'div', 'floor_divide': 'div',
'remainder': 'div',
'matmul': 'mul', 'remainder': 'div',
'sqrt': .5, 'square': 2, 'reciprocal': -1}

__skip_other_args = 'ldexp multiply ' \
__skip_other_args = 'ldexp multiply matmul ' \
'true_divide divide floor_divide fmod mod ' \
'remainder'.split()

Expand Down
20 changes: 20 additions & 0 deletions pint/testsuite/test_issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,3 +679,23 @@ def pendulum_period(length, G=Q_(1, 'standard_gravity')):
def test_issue783(self):
ureg = UnitRegistry()
assert not ureg('g') == []

@helpers.requires_numpy()
@unittest.skipIf(sys.version_info < (3, 5), 'Requires Python >= 3.5')
def test_issue859_with_numpy(self):
import operator as op
ureg = UnitRegistry()
A = [[1, 2], [3, 4]] * ureg.m
B = np.array([[0, -1], [-1, 0]])
b = [[1], [0]] * ureg.m
self.assertQuantityEqual(op.matmul(A, B), [[-2, -1], [-4, -3]] * ureg.m)
self.assertQuantityEqual(op.matmul(A, b), [[1], [3]] * ureg.m**2)
self.assertQuantityEqual(op.matmul(B, b), [[0], [-1]] * ureg.m)

@helpers.requires_not_numpy()
@unittest.skipIf(sys.version_info < (3, 5), 'Requires Python >= 3.5')
def test_issue859_without_numpy(self):
import operator as op
ureg = UnitRegistry()
A = [[1, 2], [3, 4]] * ureg.m
self.assertRaises(TypeError, op.matmul, A, A)

0 comments on commit 11ddbfa

Please sign in to comment.