diff --git a/pint/quantity.py b/pint/quantity.py index 3373552c7..5a074bedd 100644 --- a/pint/quantity.py +++ b/pint/quantity.py @@ -966,6 +966,21 @@ def __mul__(self, other): __rmul__ = __mul__ + def __matmul__(self, other): + try: + ret = np.matmul(self, other) + if not self._check(ret): + if self._check(other): + units = self._units * other._units + else: + units = self._units + ret = self.__class__(ret, units) + return ret + except (AttributeError, TypeError): + return NotImplemented + + __rmatmul__ = __matmul__ + def __itruediv__(self, other): if not isinstance(self._magnitude, ndarray): return self._mul_div(other, operator.truediv) @@ -1317,10 +1332,10 @@ def __bool__(self): #: will be raised. __prod_units = {'var': 2, 'prod': 'size', 'multiply': 'mul', 'true_divide': 'div', 'divide': 'div', 'floor_divide': 'div', - 'remainder': 'div', + 'matmul': 'mul', 'remainder': 'div', 'sqrt': .5, 'square': 2, 'reciprocal': -1} - __skip_other_args = 'ldexp multiply ' \ + __skip_other_args = 'ldexp multiply matmul ' \ 'true_divide divide floor_divide fmod mod ' \ 'remainder'.split() diff --git a/pint/testsuite/test_issues.py b/pint/testsuite/test_issues.py index 6e79601f4..7e9ca7bb9 100644 --- a/pint/testsuite/test_issues.py +++ b/pint/testsuite/test_issues.py @@ -679,3 +679,23 @@ def pendulum_period(length, G=Q_(1, 'standard_gravity')): def test_issue783(self): ureg = UnitRegistry() assert not ureg('g') == [] + + @helpers.requires_numpy() + @unittest.skipIf(sys.version_info < (3, 5), 'Requires Python >= 3.5') + def test_issue859_with_numpy(self): + import operator as op + ureg = UnitRegistry() + A = [[1, 2], [3, 4]] * ureg.m + B = np.array([[0, -1], [-1, 0]]) + b = [[1], [0]] * ureg.m + self.assertQuantityEqual(op.matmul(A, B), [[-2, -1], [-4, -3]] * ureg.m) + self.assertQuantityEqual(op.matmul(A, b), [[1], [3]] * ureg.m**2) + self.assertQuantityEqual(op.matmul(B, b), [[0], [-1]] * ureg.m) + + @helpers.requires_not_numpy() + @unittest.skipIf(sys.version_info < (3, 5), 'Requires Python >= 3.5') + def test_issue859_without_numpy(self): + import operator as op + ureg = UnitRegistry() + A = [[1, 2], [3, 4]] * ureg.m + self.assertRaises(TypeError, op.matmul, A, A)