Skip to content

Commit

Permalink
Fixed device fixture issue and concat axis argument in test_ivy_demos (
Browse files Browse the repository at this point in the history
  • Loading branch information
vedpatwardhan authored Jan 30, 2023
1 parent 6d09694 commit f8cae0f
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions ivy_tests/test_ivy/test_misc/test_ivy_demos.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# ------#

# training
def test_training_demo(device):
def test_training_demo(on_device):

if ivy.current_backend_str() == "numpy":
# numpy does not support gradients
Expand Down Expand Up @@ -43,17 +43,17 @@ def loss_fn(v):


# functional api
def test_array(device):
def test_array(on_device):
ivy.unset_backend()
import jax.numpy as jnp

assert ivy.concat((jnp.ones((1,)), jnp.ones((1,))), -1).shape == (2,)
assert ivy.concat((jnp.ones((1,)), jnp.ones((1,))), axis=-1).shape == (2,)
import tensorflow as tf

assert ivy.concat((tf.ones((1,)), tf.ones((1,))), -1).shape == (2,)
assert ivy.concat((tf.ones((1,)), tf.ones((1,))), axis=-1).shape == (2,)
import numpy as np

assert ivy.concat((np.ones((1,)), np.ones((1,))), -1).shape == (2,)
assert ivy.concat((np.ones((1,)), np.ones((1,))), axis=-1).shape == (2,)
import torch

assert ivy.concat((torch.ones((1,)), torch.ones((1,))), -1).shape == (2,)
assert ivy.concat((torch.ones((1,)), torch.ones((1,))), axis=-1).shape == (2,)

0 comments on commit f8cae0f

Please sign in to comment.