Skip to content

Commit

Permalink
Removed importing all frameworks if they're installed on importing iv…
Browse files Browse the repository at this point in the history
…y in the __init__.py and stateful/converters.py (#10720)
  • Loading branch information
vedpatwardhan authored Feb 21, 2023
1 parent ceda89e commit fa78825
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 65 deletions.
36 changes: 2 additions & 34 deletions ivy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,10 @@
# global
import copy
from types import SimpleNamespace
import warnings
from ivy._version import __version__ as __version__
import builtins
import numpy as np

try:
import torch
except ImportError:
torch = SimpleNamespace()
torch.Size = SimpleNamespace()
torch.Tensor = SimpleNamespace()

try:
import tensorflow as tf
except ImportError:
tf = SimpleNamespace()
tf.TensorShape = SimpleNamespace()
tf.Tensor = SimpleNamespace()

try:
import jax
import jaxlib
except ImportError:
jax = SimpleNamespace()
jax.interpreters = SimpleNamespace()
jax.interpreters.xla = SimpleNamespace()
jax.interpreters.xla._DeviceArray = SimpleNamespace()
jaxlib = SimpleNamespace()
jaxlib.xla_extension = SimpleNamespace()
jaxlib.xla_extension.DeviceArray = SimpleNamespace()
jaxlib.xla_extension.Buffer = SimpleNamespace()

warnings.filterwarnings("ignore", module="^(?!.*ivy).*$")

Expand Down Expand Up @@ -215,13 +188,8 @@ def __new__(cls, shape_tup):
valid_types += (ivy.NativeShape, ivy.NativeArray)
else:
valid_types += (
tf.TensorShape,
torch.Size,
jax.interpreters.xla._DeviceArray,
jaxlib.xla_extension.DeviceArray,
jax.xla_extension.Buffer,
np.ndarray,
tf.Tensor,
current_backend(shape_tup).NativeShape,
current_backend(shape_tup).NativeArray,
)
ivy.utils.assertions.check_isinstance(shape_tup, valid_types)
if isinstance(shape_tup, int):
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/ivy/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
# -------------------#


@handle_array_like_without_promotion
@to_native_arrays_and_back
@handle_out_argument
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
@handle_array_function
def all(
x: Union[ivy.Array, ivy.NativeArray],
Expand Down
37 changes: 7 additions & 30 deletions ivy/stateful/converters.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,10 @@
"""Converters from Native Modules to Ivy Modules"""
# global
from types import SimpleNamespace
from typing import Optional, Dict, List
import re # noqa
import inspect
from collections import OrderedDict

try:
import haiku as hk
from haiku._src.data_structures import FlatMapping
import jax
except ImportError:
hk = SimpleNamespace()
hk.Module = SimpleNamespace
hk.transform = SimpleNamespace
hk.get_parameter = SimpleNamespace
FlatMapping = SimpleNamespace
jax = SimpleNamespace()
jax.random = SimpleNamespace()
jax.random.PRNGKey = SimpleNamespace

try:
import torch
except ImportError:
torch = SimpleNamespace()
torch.nn = SimpleNamespace()
torch.nn.Parameter = SimpleNamespace
torch.nn.Module = SimpleNamespace

try:
import tensorflow as tf
except ImportError:
tf = SimpleNamespace()
tf.keras = SimpleNamespace()
tf.keras.Model = SimpleNamespace
import importlib

# local
import ivy
Expand Down Expand Up @@ -142,6 +113,10 @@ def from_haiku_module(
The new trainable torch module instance.
"""
if not importlib.util.find_spec("haiku"):
import haiku as hk
if not importlib.util.find_spec("FlatMapping", "haiku._src.data_structures"):
from haiku._src.data_structures import FlatMapping

def _hk_flat_map_to_dict(hk_flat_map):
ret_dict = dict()
Expand Down Expand Up @@ -369,6 +344,8 @@ def from_torch_module(
ret
The new trainable ivy.Module instance.
"""
if not importlib.util.find_spec("torch"):
import torch

class TorchIvyModule(ivy.Module):
def __init__(
Expand Down

0 comments on commit fa78825

Please sign in to comment.