-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathufuncs.py
94 lines (77 loc) · 3.25 KB
/
ufuncs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""xarray specific universal functions
Handles unary and binary operations for the following types, in ascending
priority order:
- scalars
- numpy.ndarray
- dask.array.Array
- xarray.Variable
- xarray.DataArray
- xarray.Dataset
- xarray.core.groupby.GroupBy
Once NumPy 1.10 comes out with support for overriding ufuncs, this module will
hopefully no longer be necessary.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as _np
from .core.variable import Variable as _Variable
from .core.dataset import Dataset as _Dataset
from .core.dataarray import DataArray as _DataArray
from .core.groupby import GroupBy as _GroupBy
from .core.pycompat import dask_array_type as _dask_array_type
from .core.ops import _dask_or_eager_func
# xarray container types: when an argument is an instance of one of these,
# the dispatcher hands the call to that argument's ``_unary_op`` /
# ``_binary_op`` method instead of the plain numpy/dask function.
_xarray_types = (_Variable, _DataArray, _Dataset, _GroupBy)
# Types in ascending dispatch priority; ``_dispatch_priority`` uses the index
# of the first matching entry as an object's priority.
# NOTE(review): ``_dask_array_type`` is passed straight to ``isinstance``, so
# it may be a single class or a tuple of classes -- confirm in core.pycompat.
_dispatch_order = (_np.ndarray, _dask_array_type) + _xarray_types
def _dispatch_priority(obj):
    """Return the dispatch rank of ``obj``.

    The rank is the index of the first entry of ``_dispatch_order`` that
    ``obj`` is an instance of.  Objects matching no entry (for example plain
    Python scalars) get ``-1`` and therefore rank below every listed type.
    """
    matching_ranks = (
        rank
        for rank, kind in enumerate(_dispatch_order)
        if isinstance(obj, kind)
    )
    # ``next`` with a default stops at the first hit, exactly like an early
    # ``return`` from a loop, and yields -1 when nothing matched.
    return next(matching_ranks, -1)
class _UFuncDispatcher(object):
    """Callable wrapper that dispatches a named ufunc to the right handler.

    The wrapper accepts one or two positional arguments.  When the argument
    that wins the dispatch is one of ``_xarray_types``, the call is routed
    through that object's ``_unary_op`` / ``_binary_op`` method; otherwise
    the function built by ``_dask_or_eager_func`` is called directly.
    """

    def __init__(self, name):
        # Name of the wrapped function; used for the fallback lookup and in
        # error messages.
        self._name = name

    def __call__(self, *args, **kwargs):
        """Apply the wrapped function to ``args``.

        Parameters
        ----------
        *args
            One or two operands.
        **kwargs
            Passed through unchanged to the selected implementation.

        Returns
        -------
        Whatever the selected implementation returns.

        Raises
        ------
        TypeError
            If zero or more than two positional arguments are given, or if
            the selected implementation returns ``NotImplemented``.
        """
        n_args = len(args)
        # Reject an unsupported arity before doing any other work.
        if n_args == 0 or n_args > 2:
            raise TypeError('cannot handle %s arguments for %r' %
                            (n_args, self._name))

        # Default: call the plain (dask or eager) function with the
        # arguments in the order they were given.
        new_args = args
        f = _dask_or_eager_func(self._name, n_array_args=n_args)

        if n_args == 1:
            if isinstance(args[0], _xarray_types):
                f = args[0]._unary_op(self)
        else:  # n_args == 2
            p1, p2 = map(_dispatch_priority, args)
            if p1 >= p2:
                # Left operand ranks at least as high as the right one.
                if isinstance(args[0], _xarray_types):
                    f = args[0]._binary_op(self)
            else:
                # Right operand ranks higher: let it drive the operation.
                # The operands are swapped so it comes first, and
                # ``reflexive=True`` is passed -- presumably so the wrapped
                # op restores the original operand order (confirm in
                # ``_binary_op``).
                if isinstance(args[1], _xarray_types):
                    f = args[1]._binary_op(self, reflexive=True)
                    new_args = tuple(reversed(args))

        res = f(*new_args, **kwargs)
        if res is NotImplemented:
            # Format the type of every operand actually passed; indexing
            # ``args[1]`` unconditionally would raise IndexError for a
            # one-argument call and hide the intended TypeError.
            raise TypeError('%r not implemented for types (%s)'
                            % (self._name,
                               ', '.join(repr(type(arg)) for arg in args)))
        return res
def _create_op(name):
    """Build the dispatching wrapper for the numpy function called ``name``.

    The returned ``_UFuncDispatcher`` has its ``__name__`` set to ``name``
    and a ``__doc__`` made of a short xarray-specific preamble followed by
    the docstring of ``numpy.<name>``.
    """
    doc_template = ('xarray specific variant of numpy.%s. Handles '
                    'xarray.Dataset, xarray.DataArray, xarray.Variable, '
                    'numpy.ndarray and dask.array.Array objects with '
                    'automatic dispatching.\n\n'
                    'Documentation from numpy:\n\n%s')

    dispatcher = _UFuncDispatcher(name)
    dispatcher.__name__ = name
    # Look the function up on numpy only to borrow its documentation.
    numpy_doc = getattr(_np, name).__doc__
    dispatcher.__doc__ = doc_template % (name, numpy_doc)
    return dispatcher
# Public API: names of the numpy functions this module wraps.  Each name is
# listed exactly once (``fmod`` sits next to ``ldexp``); a repeated entry
# would put a duplicate in ``__all__`` and build the same wrapper twice.
__all__ = """logaddexp logaddexp2 conj exp log log2 log10 log1p expm1 sqrt
             square sin cos tan arcsin arccos arctan arctan2 hypot sinh cosh
             tanh arcsinh arccosh arctanh deg2rad rad2deg logical_and
             logical_or logical_xor logical_not maximum minimum fmax fmin
             isreal iscomplex isfinite isinf isnan signbit copysign nextafter
             ldexp fmod floor ceil trunc degrees radians rint fix angle real
             imag fabs sign frexp
             """.split()

# Install one dispatching wrapper per exported name as a module attribute.
for name in __all__:
    globals()[name] = _create_op(name)