POND-983: Add support for __torch_function__ [upstream] #35

Open · wants to merge 1 commit into base: ponder-on-modin-0-19-0
56 changes: 50 additions & 6 deletions modin/numpy/arr.py
@@ -384,6 +384,39 @@ def __array_function__(self, func, types, args, kwargs):
            return NotImplemented
        return modin_func(*args, **kwargs)

+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        from . import array_creation as creation, array_shaping as shaping, math
+
+        func_name = func.__name__
+        modin_func = None
+        TORCH_ALIASES = {
+            "abs": math.absolute,
+            "add": math.add,
+            "mul": math.multiply,
+            "div": math.divide,
+            "sub": math.subtract,
+            "ge": cls.__ge__,
+            "gt": cls.__gt__,
+            "le": cls.__le__,
+            "lt": cls.__lt__,
+            "eq": cls.__eq__,
+            "ne": cls.__ne__,
+        }
+        if func_name in TORCH_ALIASES:
+            modin_func = TORCH_ALIASES[func_name]
+        elif hasattr(math, func_name):
+            modin_func = getattr(math, func_name)
+        elif hasattr(shaping, func_name):
+            modin_func = getattr(shaping, func_name)
+        elif hasattr(creation, func_name):
+            modin_func = getattr(creation, func_name)
+        if modin_func is None:
+            return NotImplemented
+        return modin_func(*args, **kwargs)
+
    def where(self, x=None, y=None):
        if not is_bool_dtype(self.dtype):
            raise NotImplementedError(
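
Usage note (not part of the diff): with this hook, PyTorch's __torch_function__ protocol lets torch API calls that receive a modin array dispatch back to modin's own implementations instead of torch's. The sketch below is hypothetical; it assumes torch and modin.numpy are installed and importable and that modin.numpy.array accepts list input, and none of these calls appear in the PR itself.

# Hypothetical usage sketch; names outside the diff are assumptions.
import torch
import modin.numpy as np_md

a = np_md.array([1.0, -2.0, 3.0])

# torch.abs finds __torch_function__ on `a` and hands the call back to modin;
# "abs" is mapped to math.absolute through TORCH_ALIASES.
abs_a = torch.abs(a)

# Comparison functions route through the alias table to the array dunders (__ge__, ...).
mask = torch.ge(a, 0.0)

# A torch function with no modin counterpart makes __torch_function__ return
# NotImplemented, so torch raises a TypeError rather than computing eagerly.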
@@ -716,8 +749,12 @@ def _binary_op(self, other):
                raise ValueError(
                    f"operands could not be broadcast together with shapes {self.shape} {other.shape}"
                )
-            return (caller, callee, caller._ndim, {"broadcast": broadcast, "axis": 1,
-                                                   "sort_columns":False})
+            return (
+                caller,
+                callee,
+                caller._ndim,
+                {"broadcast": broadcast, "axis": 1, "sort_columns": False},
+            )
        else:
            if self.shape != other.shape:
                # In this case, we either have two mismatched objects trying to do an operation
@@ -740,16 +777,23 @@ def _binary_op(self, other):
                        self,
                        other,
                        self._ndim,
-                        {"broadcast": broadcast, "axis": matched_dimension,
-                         "sort_columns":False},
+                        {
+                            "broadcast": broadcast,
+                            "axis": matched_dimension,
+                            "sort_columns": False,
+                        },
                    )
                else:
                    raise ValueError(
                        f"operands could not be broadcast together with shapes {self.shape} {other.shape}"
                    )
            else:
-                return (self, other, self._ndim, {"broadcast": False,
-                                                  "sort_columns":False})
+                return (
+                    self,
+                    other,
+                    self._ndim,
+                    {"broadcast": False, "sort_columns": False},
+                )

    def _greater(
        self,
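
Context note (not part of the diff): the dictionaries reformatted above are the kwargs element of the 4-tuple that _binary_op returns, which the element-wise operators presumably unpack before calling into the query compiler. A rough, hypothetical illustration follows; the downstream query-compiler call is an assumption and is not shown in this diff.

# Hypothetical illustration; the query-compiler call below is an assumption.
import modin.numpy as np_md

a = np_md.array([[1, 2, 3], [4, 5, 6]])  # 2-D operand
b = np_md.array([10, 20, 30])            # 1-D operand broadcast along axis 1

# _binary_op aligns the operands and describes how to combine them:
# (caller, callee, output ndim, kwargs such as broadcast/axis/sort_columns).
caller, callee, out_ndim, qc_kwargs = a._binary_op(b)

# An operator such as __add__ would then forward those kwargs, roughly:
# caller._query_compiler.add(callee._query_compiler, **qc_kwargs)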