Programmatically list the "Numpy" definitions not implemented by "jax.numpy". #3689
Comments
We could add some mechanism for that. Note that not-implemented functions are also defined elsewhere, for example: https://github.com/google/jax/blob/fdd7f0c8574b1e0e4366ee428153d73bc1788510/jax/numpy/lax_numpy.py#L4325-L4328

Would a global list of not-implemented functions do what you want? For example, a list of strings stored in
One way you can get the list in the current release is like this, though it's a bit of a hack:

In [1]: import inspect
   ...: import jax.numpy
   ...:
   ...: def unimplemented():
   ...:     for name in dir(jax.numpy):
   ...:         f = getattr(jax.numpy, name)
   ...:         try:
   ...:             source = inspect.getsource(f)
   ...:         except (TypeError, OSError):
   ...:             continue
   ...:         if "Numpy function {} not yet implemented" in source:
   ...:             yield name
   ...:
In [2]: list(unimplemented())
['<lambda>',
'_add_newdoc_ufunc',
'_fastCopyAndTranspose',
'add_docstring',
'add_newdoc',
'alen',
'apply_along_axis',
'apply_over_axes',
'argpartition',
'array2string',
'array_equiv',
'array_split',
'asanyarray',
'asarray_chkfinite',
'ascontiguousarray',
'asfarray',
'asfortranarray',
'asmatrix',
'asscalar',
'base_repr',
'binary_repr',
'bmat',
'busday_count',
'busday_offset',
'byte_bounds',
'choose',
'common_type',
'compare_chararrays',
'copy',
'copyto',
'datetime_as_string',
'datetime_data',
'delete',
'deprecate',
'diag_indices_from',
'disp',
'fill_diagonal',
'find_common_type',
'format_float_positional',
'format_float_scientific',
'frombuffer',
'fromfile',
'fromfunction',
'fromiter',
'frompyfunc',
'fromregex',
'fromstring',
'fv',
'genfromtxt',
'get_array_wrap',
'get_include',
'get_printoptions',
'getbufsize',
'geterr',
'geterrcall',
'geterrobj',
'histogram2d',
'histogramdd',
'i0',
'info',
'insert',
'int_asbuffer',
'interp',
'intersect1d',
'invert',
'ipmt',
'irr',
'is_busday',
'isfortran',
'isnat',
'issctype',
'issubclass_',
'lax_numpy',
'lexsort',
'loads',
'loadtxt',
'lookfor',
'mafromtxt',
'maximum_sctype',
'may_share_memory',
'min_scalar_type',
'mintypecode',
'mirr',
'modf',
'nanmedian',
'nanpercentile',
'nanquantile',
'ndfromtxt',
'nested_iters',
'nper',
'npv',
'obj2sctype',
'partition',
'piecewise',
'place',
'pmt',
'poly',
'polyder',
'polydiv',
'polyfit',
'polyint',
'ppmt',
'printoptions',
'put',
'put_along_axis',
'putmask',
'pv',
'rate',
'ravel_multi_index',
'real_if_close',
'recfromcsv',
'recfromtxt',
'require',
'resize',
'round_',
'safe_eval',
'savetxt',
'savez_compressed',
'sctype2char',
'set_numeric_ops',
'set_string_function',
'setbufsize',
'setdiff1d',
'seterr',
'seterrcall',
'seterrobj',
'setxor1d',
'shares_memory',
'show',
'sort_complex',
'source',
'spacing',
'tril_indices_from',
'trim_zeros',
'triu_indices_from',
'typename',
'union1d',
'unwrap',
'who']
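The source-inspection hack only reports names that jax.numpy actually defines as not-implemented stubs, because it loops over dir(jax.numpy). Names that jax.numpy does not define at all can be found with a plain set difference over dir(). A minimal sketch, using made-up stand-in modules so it runs on its own; the helper name missing_public_names is hypothetical:

```python
import types


def missing_public_names(reference, candidate):
    """Return the sorted public names of `reference` that `candidate` lacks."""
    ref_names = {name for name in dir(reference) if not name.startswith("_")}
    return sorted(ref_names - set(dir(candidate)))


# Stand-in modules so the sketch runs without NumPy or JAX installed.
reference = types.ModuleType("reference")
reference.copy = lambda x: x
reference.sort = sorted
reference._private_helper = None  # underscore names are ignored

candidate = types.ModuleType("candidate")
candidate.sort = sorted

print(missing_public_names(reference, candidate))  # ['copy']

# With both libraries installed, the real comparison would be:
#   import numpy, jax.numpy
#   missing_public_names(numpy, jax.numpy)
```

Combining this with the stub detection covers both kinds of gaps. Note that dir() also lists submodules and constants, so the result may need filtering down to callables.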
Hi @jakevdp, thanks! A simple attribute holding a list of strings would be fantastic and super clean. Introspecting the source is an excellent idea; I had not thought of it, but it will do for our tests!
With #3697, you can now do this:

import jax.numpy as jnp
print(jnp._NOT_IMPLEMENTED)

The variable contains a list of the names of unimplemented functions.
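Because _NOT_IMPLEMENTED starts with an underscore, it is private by convention and may not be present in every JAX release, so a test suite would do well to read it defensively. A minimal sketch, assuming the attribute is a list of names as described above, that splits the functions a project relies on into supported and unsupported ones; split_supported and the stand-in module are hypothetical, and with JAX installed you would pass jax.numpy itself:

```python
import types


def split_supported(module, names):
    """Partition `names` into (supported, unsupported) for `module`.

    A name is unsupported if the module lacks it or if it appears in the
    module's `_NOT_IMPLEMENTED` list. `getattr` with a default keeps this
    working on versions that do not define that private attribute.
    """
    not_implemented = set(getattr(module, "_NOT_IMPLEMENTED", ()))
    supported, unsupported = [], []
    for name in names:
        if hasattr(module, name) and name not in not_implemented:
            supported.append(name)
        else:
            unsupported.append(name)
    return supported, unsupported


# Stand-in for jax.numpy so the sketch runs on its own.
fake_jnp = types.ModuleType("fake_jnp")
fake_jnp.sort = sorted
fake_jnp.copy = lambda x: x  # present, but listed as a stub below
fake_jnp._NOT_IMPLEMENTED = ["copy"]

print(split_supported(fake_jnp, ["sort", "copy", "trim_zeros"]))
# (['sort'], ['copy', 'trim_zeros'])
```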
Awesome, @jakevdp, and thanks again!
Hi,
I'm trying to use JAX as a computation backend for Colour, and I have no easy (and fast) way to programmatically find which NumPy definitions are or are not supported by jax.numpy. The problem is that I can only discover that when they are called. What I am currently looking at is a mechanism that routes the definitions to the selected backend and falls back to NumPy if they do not exist or are not implemented.
This is the relevant content of the test colour.ndarray.backend module:

Then the colour.ndarray.__init__ module is implemented as follows:

Thus, instead of import numpy as np, I can now write import colour.ndarray as np, and this routes the code according to the _NDIMENSIONAL_ARRAY_BACKEND global.

The problem is that if some of my code uses a JAX definition that is not implemented, e.g. np.copy, it raises an exception. The list of not-implemented definitions would be trivial to set somewhere when looking at the jax.numpy.__init__ module here: https://github.com/google/jax/blob/a44bc0c2c05aa4a079eda3995379dab4a63182dc/jax/numpy/__init__.py#L76

Hope that makes sense!
Cheers
Thomas
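The routing-with-fallback idea described above can also work without any list, by trying the selected backend and falling back to NumPy when the call fails. A minimal sketch, assuming not-implemented stubs raise NotImplementedError when called; route and the stand-in modules are hypothetical, and in a module like colour.ndarray the backend argument could be chosen from the _NDIMENSIONAL_ARRAY_BACKEND global:

```python
import functools
import types


def route(name, backend, fallback):
    """Return a callable for `name` that prefers `backend` over `fallback`.

    Falls back when `backend` does not define `name`, or when calling the
    backend's version raises NotImplementedError (assumed here to be how
    not-yet-implemented stubs signal themselves).
    """
    fallback_fn = getattr(fallback, name)  # raises AttributeError if the fallback lacks `name`
    backend_fn = getattr(backend, name, None)
    if backend_fn is None:
        return fallback_fn

    @functools.wraps(fallback_fn)
    def call_with_fallback(*args, **kwargs):
        try:
            return backend_fn(*args, **kwargs)
        except NotImplementedError:
            return fallback_fn(*args, **kwargs)

    return call_with_fallback


# Stand-in modules so the sketch runs without JAX or NumPy installed.
def _stub(*args, **kwargs):
    raise NotImplementedError("Numpy function copy not yet implemented")


fake_jnp = types.SimpleNamespace(copy=_stub, negative=lambda x: ("jax", -x))
fake_np = types.SimpleNamespace(
    copy=lambda x: ("numpy", list(x)),
    negative=lambda x: ("numpy", -x),
    sort=lambda x: ("numpy", sorted(x)),
)

print(route("negative", fake_jnp, fake_np)(3))  # ('jax', -3)
print(route("copy", fake_jnp, fake_np)([1, 2]))  # ('numpy', [1, 2])
print(route("sort", fake_jnp, fake_np)([2, 1]))  # ('numpy', [1, 2])
```

One trade-off: the backend is attempted on every call, and a NotImplementedError raised for any other reason inside the backend would also trigger the fallback. Checking the name against a list up front, when one is available, avoids both.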