Skip to content

Commit

Permalink
add parallel=False mode to iterate_hdf
Browse files Browse the repository at this point in the history
  • Loading branch information
joezuntz committed Mar 8, 2021
1 parent b5d386f commit 3982fd9
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:

- name: Install
run: |
pip install --upgrade pytest codecov pytest-cov
pip install --upgrade pytest codecov pytest-cov h5py
pip install .[test,cwl,parsl]
- name: Tests
Expand Down
18 changes: 11 additions & 7 deletions ceci/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,13 +448,13 @@ def split_tasks_by_rank(self, tasks):
----------
tasks: iterable
Tasks to split up
"""
for i, task in enumerate(tasks):
if i % self.size == self.rank:
yield task

def data_ranges_by_rank(self, n_rows, chunk_rows):
def data_ranges_by_rank(self, n_rows, chunk_rows, parallel=True):
"""Split a number of rows by process.
Given a total number of rows to read and a chunk size, yield
Expand All @@ -471,7 +471,11 @@ def data_ranges_by_rank(self, n_rows, chunk_rows):
n_chunks = n_rows // chunk_rows
if n_chunks * chunk_rows < n_rows:
n_chunks += 1
for i in self.split_tasks_by_rank(range(n_chunks)):
if parallel:
it = self.split_tasks_by_rank(range(n_chunks))
else:
it = range(n_chunks)
for i in it:
start = i * chunk_rows
end = min((i + 1) * chunk_rows, n_rows)
yield start, end
Expand Down Expand Up @@ -721,7 +725,7 @@ def read_config(self, args):

return my_config

def iterate_fits(self, tag, hdunum, cols, chunk_rows):
def iterate_fits(self, tag, hdunum, cols, chunk_rows, parallel=True):
"""
Loop through chunks of the input data from a FITS file with the given tag
Expand All @@ -730,11 +734,11 @@ def iterate_fits(self, tag, hdunum, cols, chunk_rows):
fits = self.open_input(tag)
ext = fits[hdunum]
n = ext.get_nrows()
for start, end in self.data_ranges_by_rank(n, chunk_rows):
for start, end in self.data_ranges_by_rank(n, chunk_rows, parallel=True):
data = ext.read_columns(cols, rows=range(start, end))
yield start, end, data

def iterate_hdf(self, tag, group_name, cols, chunk_rows):
def iterate_hdf(self, tag, group_name, cols, chunk_rows, parallel=True):
"""
Loop through chunks of the input data from an HDF5 file with the given tag.
Expand All @@ -757,7 +761,7 @@ def iterate_hdf(self, tag, group_name, cols, chunk_rows):
)

# Iterate through the data providing chunks
for start, end in self.data_ranges_by_rank(n, chunk_rows):
for start, end in self.data_ranges_by_rank(n, chunk_rows, parallel=parallel):
data = {col: group[col][start:end] for col in cols}
yield start, end, data

Expand Down
9 changes: 9 additions & 0 deletions ceci_example/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ class DataFile:
"""

def __init__(self, path, mode, extra_provenance=None, validate=True, **kwargs):
self.path = path
self.mode = mode

if mode not in ["r", "w"]:
raise ValueError(f"File 'mode' argument must be 'r' or 'w' not '{mode}'")

self.file = self.open(path, mode, **kwargs)

@classmethod
def open(cls, path, mode):
"""
Expand Down
27 changes: 24 additions & 3 deletions tests/test_stage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from ceci.stage import PipelineStage
from ceci_example.types import HDFFile
import pytest
import numpy as np

# TODO: test MPI facilities properly with:
# https://github.com/rmjarvis/TreeCorr/blob/releases/4.1/tests/mock_mpi.py
Expand All @@ -26,14 +28,14 @@ class TestStage(PipelineStage):
# This one should work
class TestStage(PipelineStage):
name = "test"
inputs = []
inputs = [("inp1", HDFFile)]
outputs = []
config = {}

assert PipelineStage.get_stage("test") == TestStage
assert TestStage.get_module().endswith("test_stage")

s = TestStage({"config": "tests/config.yml"})
s = TestStage({"config": "tests/config.yml", "inp1": "tests/test.hdf5"})

assert s.rank == 0
assert s.size == 1
Expand All @@ -45,9 +47,15 @@ class TestStage(PipelineStage):
assert r[0] == (0, 100)
assert r[2] == (200, 300)

r = list(s.iterate_hdf("inp1", "group1", ["x", "y", "z"], 10))
for ri in r:
s, e, ri = ri
assert len(ri["x"] == 10)
assert np.all(r[4][2]["z"] == [-80, -82, -84, -86, -88, -90, -92, -94, -96, -98])

# Fake that we are processor 4/10
comm = MockCommunicator(10, 4)
s = TestStage({"config": "tests/config.yml"}, comm=comm)
s = TestStage({"config": "tests/config.yml", "inp1": "tests/test.hdf5"}, comm=comm)

assert s.rank == 4
assert s.size == 10
Expand All @@ -59,6 +67,19 @@ class TestStage(PipelineStage):
assert r[0] == (400, 500)
assert r[3] == (3400, 3500)

r = list(s.iterate_hdf("inp1", "group1", ["x", "y", "z"], 10))
assert len(r) == 1
st, e, r = r[0]
assert st == 40
assert e == 50
assert np.all(r["x"] == range(40, 50))

r = list(s.iterate_hdf("inp1", "group1", ["x", "y", "z"], 10, parallel=False))
for ri in r:
s, e, ri = ri
assert len(ri["x"] == 10)
assert np.all(r[4][2]["z"] == [-80, -82, -84, -86, -88, -90, -92, -94, -96, -98])

# I'd rather not attempt to unit test MPI stuff - that sounds very unreliable


Expand Down

0 comments on commit 3982fd9

Please sign in to comment.