Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Feb 1, 2025
1 parent 7093a26 commit 43109eb
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions brainunit/_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import jax.numpy as jnp
import numpy as np
import pytest
from numpy.ma.core import indices
from numpy.testing import assert_equal

import brainunit as u
Expand Down Expand Up @@ -768,23 +769,23 @@ def test_deepcopy(self):
assert d_copy["x"] == 2 * second
assert d["x"] == 1 * second

def test_numpy_functions_indices(self):
def test_indices_functions(self):
"""
Check numpy functions that return indices.
"""
values = [np.array([-4, 3, -2, 1, 0]), np.ones((3, 3)), np.array([17])]
units = [volt, second, siemens, mV, kHz]

# numpy functions
keep_dim_funcs = [np.argmin, np.argmax, np.argsort, np.nonzero]
indice_funcs = [u.math.argmin, u.math.argmax, u.math.argsort, u.math.nonzero]

for value, unit in itertools.product(values, units):
q_ar = value * unit
for func in keep_dim_funcs:
for func in indice_funcs:
test_ar = func(q_ar)
# Compare it to the result on the same value without units
comparison_ar = func(value)
test_ar = np.asarray(test_ar)
test_ar = u.math.asarray(test_ar)
comparison_ar = np.asarray(comparison_ar)
assert_equal(
test_ar,
Expand Down

0 comments on commit 43109eb

Please sign in to comment.