Skip to content

Commit

Permalink
Add _NOT_IMPLEMENTED attribute to jax.numpy (fixes jax-ml#3689)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jul 9, 2020
1 parent c1aeb8b commit 2a0fcbc
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 6 deletions.
3 changes: 2 additions & 1 deletion jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
tan, tanh, tensordot, tile, trace, trapz, transpose, tri, tril, tril_indices, tril_indices_from,
triu, triu_indices, triu_indices_from, true_divide, trunc, uint16, uint32, uint64, uint8, unique,
unpackbits, unravel_index, unsignedinteger, unwrap, vander, var, vdot, vsplit,
vstack, where, zeros, zeros_like)
vstack, where, zeros, zeros_like, _NOT_IMPLEMENTED)

from .polynomial import roots
from .vectorize import vectorize
Expand All @@ -73,6 +73,7 @@ def _init():
# Builds a set of all unimplemented NumPy functions.
for name, func in util.get_module_functions(np).items():
if name not in globals():
_NOT_IMPLEMENTED.append(name)
globals()[name] = lax_numpy._not_implemented(func)

_init()
Expand Down
2 changes: 2 additions & 0 deletions jax/numpy/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,8 @@ def ifftshift(x, axes=None):
return jnp.roll(x, shift, axes)


_NOT_IMPLEMENTED = []
for name, func in get_module_functions(np.fft).items():
if name not in globals():
_NOT_IMPLEMENTED.append(name)
globals()[name] = _not_implemented(func)
1 change: 1 addition & 0 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4326,6 +4326,7 @@ def _operator_round(number, ndigits=None):
# _not_implemented implementations of them here rather than in __init__.py.
# TODO(phawkins): implement these.
argpartition = _not_implemented(np.argpartition)
_NOT_IMPLEMENTED = ['argpartition']

# Set up operator, method, and property forwarding on Tracer instances containing
# ShapedArray avals by following the forwarding conventions for Tracer.
Expand Down
12 changes: 7 additions & 5 deletions jax/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,11 +479,6 @@ def solve(a, b):
return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)


for name, func in get_module_functions(np.linalg).items():
if name not in globals():
globals()[name] = _not_implemented(func)


@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\
It has two important differences:
Expand Down Expand Up @@ -535,3 +530,10 @@ def lstsq(a, b, rcond=None, *, numpy_resid=False):
if b_orig_ndim == 1:
x = x.ravel()
return x, resid, rank, s


_NOT_IMPLEMENTED = []
for name, func in get_module_functions(np.linalg).items():
if name not in globals():
_NOT_IMPLEMENTED.append(name)
globals()[name] = _not_implemented(func)
8 changes: 8 additions & 0 deletions jax/numpy/polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ._util import _wraps
from .linalg import eigvals as _eigvals
from .. import ops as jaxops
from ..util import get_module_functions


def _to_inexact_type(type):
Expand Down Expand Up @@ -102,3 +103,10 @@ def roots(p, *, strip_zeros=True):
# combine roots and zero roots
roots = jnp.hstack((roots, jnp.zeros(trailing_zeros, p.dtype)))
return roots


_NOT_IMPLEMENTED = []
for name, func in get_module_functions(np.polynomial).items():
if name not in globals():
_NOT_IMPLEMENTED.append(name)
globals()[name] = _not_implemented(func)
6 changes: 6 additions & 0 deletions tests/fft_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ def _zero_for_irfft(z, axes):

class FftTest(jtu.JaxTestCase):

def testNotImplemented(self):
for name in jnp.fft._NOT_IMPLEMENTED:
func = getattr(jnp.fft, name)
with self.assertRaises(NotImplementedError):
func()

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inverse={}_real={}_shape={}_axes={}".format(
inverse, real, jtu.format_shape_dtype_string(shape, dtype), axes),
Expand Down
6 changes: 6 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,12 @@ def f():
for a in out]
return f

def testNotImplemented(self):
for name in jnp._NOT_IMPLEMENTED:
func = getattr(jnp, name)
with self.assertRaises(NotImplementedError):
func()

@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
Expand Down
6 changes: 6 additions & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def _skip_if_unsupported_type(dtype):

class NumpyLinalgTest(jtu.JaxTestCase):

def testNotImplemented(self):
for name in jnp.linalg._NOT_IMPLEMENTED:
func = getattr(jnp.linalg, name)
with self.assertRaises(NotImplementedError):
func()

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
Expand Down
6 changes: 6 additions & 0 deletions tests/polynomial_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@

class TestPolynomial(jtu.JaxTestCase):

def testNotImplemented(self):
for name in jnp.polynomial._NOT_IMPLEMENTED:
func = getattr(jnp.polynomial, name)
with self.assertRaises(NotImplementedError):
func()

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dtype={}_leading={}_trailing={}".format(
jtu.format_shape_dtype_string((length+leading+trailing,), dtype),
Expand Down

0 comments on commit 2a0fcbc

Please sign in to comment.