Skip to content

Commit

Permalink
fix (stateful)(utilities): adding tf.keras.layers.Layer as a suppor…
Browse files Browse the repository at this point in the history
…ted module for TF
  • Loading branch information
YushaArif99 committed Sep 19, 2024
1 parent b442206 commit 7e8d08f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ivy/stateful/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
def _is_submodule(obj, kw):
cls_str = {
"torch": ("torch.nn.modules.module.Module",),
"keras": ("keras.engine.training.Model", "keras.src.models.model.Model"),
"keras": ("keras.engine.training.Model", "keras.src.models.model.Model", "keras.src.layers.layer.Layer"),
"flax": ("flax.nnx.nnx.module.Module",),
}[kw]
try:
Expand Down Expand Up @@ -578,7 +578,7 @@ def _compute_module_dict_tf(model, prefix=""):

try:
assert isinstance(
model_tf, tf.keras.Model
model_tf, (tf.keras.Model,tf.keras.layers.Layer)
), "The second model must be an instance of `tf.keras.Model` (TensorFlow)."
except AssertionError as e:
raise TypeError("The second model must be a TensorFlow model.") from e
Expand Down

0 comments on commit 7e8d08f

Please sign in to comment.