Skip to content

Commit

Permalink
remove mxnet for now
Browse files Browse the repository at this point in the history
  • Loading branch information
mattbarrett98 committed Sep 11, 2022
1 parent c8f4af4 commit e4d5c81
Show file tree
Hide file tree
Showing 39 changed files with 27 additions and 2,757 deletions.
6 changes: 2 additions & 4 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ Contents
Overview
--------

Ivy is an ML framework that currently supports JAX, TensorFlow, PyTorch, MXNet, and Numpy.
Ivy is an ML framework that currently supports JAX, TensorFlow, PyTorch, and Numpy.
We’re very excited for you to try it out!

Next on our roadmap is to support automatic code conversions between all frameworks 🔄,
Expand Down Expand Up @@ -144,7 +144,7 @@ You can immediately use Ivy to train a neural network, using your favorite frame
print('Finished training!')
This example uses PyTorch as a backend framework,
but the backend can easily be changed to your favorite frameworks, such as TensorFlow, JAX, or MXNet.
but the backend can easily be changed to your favorite frameworks, such as TensorFlow, or JAX.

**Framework Agnostic Functions**

Expand All @@ -156,15 +156,13 @@ This is the same for ALL Ivy functions. They can accept tensors from any framewo
import jax.numpy as jnp
import tensorflow as tf
import numpy as np
import mxnet as mx
import torch
import ivy
jax_concatted = ivy.concat((jnp.ones((1,)), jnp.ones((1,))), -1)
tf_concatted = ivy.concat((tf.ones((1,)), tf.ones((1,))), -1)
np_concatted = ivy.concat((np.ones((1,)), np.ones((1,))), -1)
mx_concatted = ivy.concat((mx.nd.ones((1,)), mx.nd.ones((1,))), -1)
torch_concatted = ivy.concat((torch.ones((1,)), torch.ones((1,))), -1)
To see a list of all Ivy methods, type :code:`ivy.` into a python command prompt and press :code:`tab`.
Expand Down
2 changes: 1 addition & 1 deletion docs/partial_source/background/standardization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Skepticism

With our central goal being to unify all ML frameworks, you would be entirely forgiven for raising an eyebrow 🤨

“You want to try and somehow unify: TensorFlow, PyTorch, JAX, NumPy, MXNet and others, all of which have strong industrial backing, huge user momentum, and significant API differences?”
“You want to try and somehow unify: TensorFlow, PyTorch, JAX, NumPy and others, all of which have strong industrial backing, huge user momentum, and significant API differences?”

Won’t adding a new “unified” framework just make the problem even worse…

Expand Down
11 changes: 0 additions & 11 deletions docs/partial_source/deep_dive/7_data_types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -270,17 +270,6 @@ Jax:
device: jaxlib.xla_extension.Device,
) -> JaxArray:
MXNet:

.. code-block:: python
def zeros(
shape: Union[int, Sequence[int]],
*,
dtype: type,
device: mx.context.Context,
) -> mx.nd.NDArray:
NumPy:

.. code-block:: python
Expand Down
11 changes: 0 additions & 11 deletions docs/partial_source/deep_dive/8_devices.rst
Original file line number Diff line number Diff line change
Expand Up @@ -125,17 +125,6 @@ Jax:
device: jaxlib.xla_extension.Device,
) -> JaxArray:
MXNet:

.. code-block:: python
def zeros(
shape: Union[int, Tuple[int], List[int]],
*,
dtype: type,
device: mx.context.Context,
) -> mx.nd.NDArray:
NumPy:

.. code-block:: python
Expand Down
28 changes: 0 additions & 28 deletions docs/partial_source/deep_dive/9_inplace_updates.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,27 +66,6 @@ JAX **does not** natively support inplace updates,
and so there is no way of actually inplace updating the :code:`JaxArray` instance :code:`x_native`.
Therefore, an inplace update is only performed on :code:`ivy.Array` instances provided in the input.

**MXNet**:

.. code-block:: python
def inplace_update(
x: Union[ivy.Array, mx.nd.NDArray],
val: Union[ivy.Array, mx.nd.NDArray],
ensure_in_backend: bool = False,
) -> ivy.Array:
(x_native, val_native), _ = ivy.args_to_native(x, val)
x_native[:] = val_native
if ivy.is_ivy_array(x):
x.data = x_native
else:
x = ivy.Array(x_native)
return x
MXNet **does** natively support inplace updates,
and so :code:`x_native` is updated inplace with :code:`val_native`.
Following this, an inplace update is then also performed on the :code:`ivy.Array` instance, if provided in the input.

**NumPy**:

.. code-block:: python
Expand Down Expand Up @@ -213,13 +192,6 @@ The implementations of :code:`ivy.tan` for each backend are as follows.
def tan(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jnp.tan(x)
**MXNet** (no :code:`support_native_out` attribute):

.. code-block:: python
def tan(x: mx.nd.NDArray, /, *, out: Optional[mx.nd.NDArray] = None) -> mx.nd.NDArray:
return mx.nd.tan(x)
**NumPy** (includes :code:`support_native_out` attribute):

.. code-block:: python
Expand Down
9 changes: 3 additions & 6 deletions ivy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import jax
import jaxlib
from jaxlib.xla_extension import Buffer
import mxnet as mx
import numpy as np
import tensorflow as tf
from tensorflow.python.types.core import Tensor
Expand All @@ -19,7 +18,7 @@

class FrameworkStr(str):
def __new__(cls, fw_str):
assert fw_str in ["jax", "tensorflow", "torch", "mxnet", "numpy"]
assert fw_str in ["jax", "tensorflow", "torch", "numpy"]
return str.__new__(cls, fw_str)


Expand All @@ -31,19 +30,18 @@ class Framework:
jax.interpreters.xla._DeviceArray,
jaxlib.xla_extension.DeviceArray,
Buffer,
mx.nd.NDArray,
np.ndarray,
Tensor,
torch.Tensor,
]


NativeVariable = Union[
jax.interpreters.xla._DeviceArray, mx.nd.NDArray, np.ndarray, Tensor, torch.Tensor
jax.interpreters.xla._DeviceArray, np.ndarray, Tensor, torch.Tensor
]


NativeDevice = Union[jaxlib.xla_extension.Device, mx.context.Context, str, torch.device]
NativeDevice = Union[jaxlib.xla_extension.Device, str, torch.device]


NativeDtype = Union[jnp.dtype, np.dtype, tf.DType, torch.dtype]
Expand Down Expand Up @@ -399,7 +397,6 @@ class Node(str):
try_import_ivy_jax,
try_import_ivy_tf,
try_import_ivy_torch,
try_import_ivy_mxnet,
try_import_ivy_numpy,
clear_backend_stack,
)
Expand Down
22 changes: 2 additions & 20 deletions ivy/backend_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,18 @@ def __exit__(self, exc_type, exc_val, exc_tb):
_array_types["jaxlib.xla_extension"] = "ivy.functional.backends.jax"
_array_types["tensorflow.python.framework.ops"] = "ivy.functional.backends.tensorflow"
_array_types["torch"] = "ivy.functional.backends.torch"
_array_types["mxnet.ndarray.ndarray"] = "ivy.functional.backends.mxnet"

_backend_dict = dict()
_backend_dict["numpy"] = "ivy.functional.backends.numpy"
_backend_dict["jax"] = "ivy.functional.backends.jax"
_backend_dict["tensorflow"] = "ivy.functional.backends.tensorflow"
_backend_dict["torch"] = "ivy.functional.backends.torch"
_backend_dict["mxnet"] = "ivy.functional.backends.mxnet"

_backend_reverse_dict = dict()
_backend_reverse_dict["ivy.functional.backends.numpy"] = "numpy"
_backend_reverse_dict["ivy.functional.backends.jax"] = "jax"
_backend_reverse_dict["ivy.functional.backends.tensorflow"] = "tensorflow"
_backend_reverse_dict["ivy.functional.backends.torch"] = "torch"
_backend_reverse_dict["ivy.functional.backends.mxnet"] = "mxnet"


# Backend Getting/Setting #
Expand Down Expand Up @@ -284,7 +281,7 @@ def get_backend(backend: Optional[str] = None):
----------
backend
The backend for which we want to retrieve Ivy's backend i.e. one of 'jax',
'torch', 'tensorflow', 'numpy', 'mxnet'.
'torch', 'tensorflow', 'numpy'.
Returns
-------
Expand Down Expand Up @@ -435,20 +432,6 @@ def try_import_ivy_torch(warn=False):
)


def try_import_ivy_mxnet(warn=False):
try:
import ivy.functional.backends.mxnet

return ivy.functional.backends.mxnet
except (ImportError, ModuleNotFoundError) as e:
if not warn:
return
logging.warning(
"{}\n\nmxnet does not appear to be installed, "
"ivy.functional.backends.mxnet can therefore not be imported.\n".format(e)
)


def try_import_ivy_numpy(warn=False):
try:
import ivy.functional.backends.numpy
Expand All @@ -467,15 +450,14 @@ def try_import_ivy_numpy(warn=False):
"jax": try_import_ivy_jax,
"tensorflow": try_import_ivy_tf,
"torch": try_import_ivy_torch,
"mxnet": try_import_ivy_mxnet,
"numpy": try_import_ivy_numpy,
}


def choose_random_backend(excluded=None):
excluded = list() if excluded is None else excluded
while True:
if len(excluded) == 5:
if len(excluded) == 4:
raise Exception(
"Unable to select backend, all backends are either excluded "
"or not installed."
Expand Down
2 changes: 0 additions & 2 deletions ivy/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
"jax": [],
"tensorflow": [],
"torch": [],
"mxnet": ["ndarray"],
}

NATIVE_KEYS_TO_SKIP = {
Expand All @@ -27,7 +26,6 @@
"type",
"requires_grad_",
],
"mxnet": [],
}

# Helpers #
Expand Down
Loading

0 comments on commit e4d5c81

Please sign in to comment.