Skip to content

Commit

Permalink
Merge pull request #2 from xgcm/test-with-dask
Browse files Browse the repository at this point in the history
Test with dask
  • Loading branch information
rabernat authored Oct 7, 2020
2 parents c828391 + 2cbccf2 commit 4301736
Showing 1 changed file with 45 additions and 11 deletions.
56 changes: 45 additions & 11 deletions fastjmd95/test/test_jmd95.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from fastjmd95 import rho, drhodt, drhods
from .reference_values import rho_expected, drhodt_expected, drhods_expected

import dask
import dask.array

@pytest.fixture
def s_t_p():
Expand All @@ -14,20 +16,52 @@ def s_t_p():
p, t, s = np.array(list(product(p0, t0, s0))).transpose()
return s, t, p

def _chunk(*args):
return [dask.array.from_array(a, chunks=(100,)) for a in args]

def test_rho(s_t_p):
s, t, p = s_t_p
rho_actual = rho(s, t, p)
np.testing.assert_allclose(rho_actual, rho_expected)

@pytest.fixture
def no_client():
return None

def test_drhot(s_t_p):
s, t, p = s_t_p
drhodt_actual = drhodt(s, t, p)
np.testing.assert_allclose(drhodt_actual, drhodt_expected, rtol=1e-2)

@pytest.fixture
def threaded_client():
with dask.config.set(scheduler='threads'):
yield


@pytest.fixture
def processes_client():
with dask.config.set(scheduler='processes'):
yield


@pytest.fixture(scope='module')
def distributed_client():
from dask.distributed import Client, LocalCluster
cluster = LocalCluster(threads_per_worker=1,
n_workers=2,
processes=True)
client = Client(cluster)
yield
client.close()
del client
cluster.close()
del cluster


def test_drhos(s_t_p):
all_clients = ['no_client', 'threaded_client', 'processes_client', 'distributed_client']
# https://stackoverflow.com/questions/45225950/passing-yield-fixtures-as-test-parameters-with-a-temp-directory
@pytest.mark.parametrize('client', all_clients)
@pytest.mark.parametrize('function,expected',
[(rho, rho_expected),
(drhodt, drhodt_expected),
(drhods, drhods_expected)])
def test_functions(request, client, s_t_p, function, expected):
s, t, p = s_t_p
drhods_actual = drhods(s, t, p)
np.testing.assert_allclose(drhods_actual, drhods_expected, rtol=1e-2)
if client != 'no_client':
s, t, p = _chunk(s, t, p)
client = request.getfixturevalue(client)
actual = function(s, t, p)
np.testing.assert_allclose(actual, expected, rtol=1e-2)

0 comments on commit 4301736

Please sign in to comment.