
Merge pull request #118 from LSSTDESC/doc-updates
Move methods from TXPipe base stage and update documentation
joezuntz authored Feb 5, 2025
2 parents 5800243 + cade508 commit 2c3b000
Showing 2 changed files with 267 additions and 15 deletions.
243 changes: 228 additions & 15 deletions ceci/stage.py
@@ -9,6 +9,7 @@
import pdb
import datetime
import warnings
import socket

from abc import abstractmethod
from . import errors
@@ -147,7 +148,10 @@ def get_aliased_tag(self, tag):

@abstractmethod
def run(self): # pragma: no cover
"""Run the stage and return the execution status"""
"""Run the stage and return the execution status.
Subclasses must implement this method.
"""
raise NotImplementedError("run")

def validate(self):
@@ -348,7 +352,8 @@ def __init_subclass__(cls, **kwargs):
path = pathlib.Path(filename).resolve()

# Add a description of the parameters to the end of the docstring
if stage_is_complete:
# If no config options are specified, omit this.
if stage_is_complete and cls.config_options:
config_text = cls._describe_configuration_text()
if cls.__doc__ is None:
cls.__doc__ = f"Stage {cls.name}\n\nConfiguration Parameters:\n{config_text}"
@@ -810,13 +815,24 @@ def is_parallel(self):

def is_mpi(self):
"""
Returns True if the stage is being run under MPI.
Check if the stage is being run under MPI.
Returns
-------
bool
True if the stage is being run under MPI
"""
return self._parallel == MPI_PARALLEL

def is_dask(self):
"""
Returns True if the stage is being run in parallel with Dask.
Check if the stage is being run in parallel with Dask.
Returns
-------
bool
True if the stage is being run under Dask
"""
return self._parallel == DASK_PARALLEL
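
A sketch (not from this diff) of how a run() implementation might branch on these flags:

from ceci import PipelineStage

class ParallelAwareStage(PipelineStage):
    name = "ParallelAwareStage"
    inputs = []
    outputs = []
    config_options = {}

    def run(self):
        if self.is_mpi():
            self.comm.Barrier()  # an MPI communicator is available as self.comm
        elif self.is_dask():
            pass  # submit work through the dask client instead
        else:
            pass  # plain serial execution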

@@ -967,6 +983,11 @@ def data_ranges_by_rank(self, n_rows, chunk_rows, parallel=True):
parallel: bool
Whether to split data by rank or just give all procs all data.
Default=True
Returns
-------
start, end: tuple
The start and end of the range of rows to be read by this process
"""
n_chunks = n_rows // chunk_rows
if n_chunks * chunk_rows < n_rows: # pragma: no cover
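
A worked example of the rounding logic visible here (values are illustrative):

n_rows, chunk_rows = 1050, 100
n_chunks = n_rows // chunk_rows     # 10 complete chunks
if n_chunks * chunk_rows < n_rows:  # 1000 < 1050: 50 rows left over
    n_chunks += 1                   # presumably what the elided branch does
print(n_chunks)                     # 11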
@@ -988,6 +1009,17 @@ def get_input(self, tag):
"""
Return the path of an input file with the given tag,
which can be aliased.

Parameters
----------
tag: str
Tag as listed in self.inputs

Returns
-------
path: str
The path to the input file
"""
tag = self.get_aliased_tag(tag)
return self._inputs[tag]
@@ -1000,7 +1032,21 @@ def get_output(self, tag, final_name=False):
which can be aliased already.
If final_name is False then use a temporary name - file will
be moved to its final name at the end
be moved to its final name at the end. The temporary name
is prefixed with `inprogress_`.
Parameters
----------
tag: str
Tag as listed in self.outputs
final_name: bool
Default=False. Whether to save to the final name.
Returns
-------
path: str
The path to the output file
"""

tag = self.get_aliased_tag(tag)
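
A hedged sketch of the naming behaviour described above (the tag and path are hypothetical, and `stage` is assumed to be a configured instance):

# Suppose the "tomography_catalog" output maps to ./tomography_catalog.hdf5
path = stage.get_output("tomography_catalog")
# -> "./inprogress_tomography_catalog.hdf5" (temporary, in-progress name)
path = stage.get_output("tomography_catalog", final_name=True)
# -> "./tomography_catalog.hdf5" (final name)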
@@ -1023,6 +1069,21 @@ def open_input(self, tag, wrapper=False, **kwargs):
For specialized file types like FITS or HDF5 it will return
a more specific object - see the types.py file for more info.
Parameters
----------
tag: str
Tag as listed in self.inputs
wrapper: bool
Whether to return an underlying file object (False) or a data type instance (True)
**kwargs: dict
Extra arguments to pass to the file class constructor
Returns
-------
obj: file or object
The opened file or object
"""
path = self.get_input(tag)
input_class = self.get_input_type(tag)
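
A usage sketch inside a stage's run() method, assuming a hypothetical HDF5 input tagged "shear_catalog":

# wrapper=False (default) returns the underlying file object;
# wrapper=True returns the ceci data-type instance from types.py.
f = self.open_input("shear_catalog")                 # e.g. an h5py.File
dt = self.open_input("shear_catalog", wrapper=True)  # ceci file-type wrapper
ra = f["shear/ra"][:]                                # hypothetical dataset path
f.close()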
@@ -1039,7 +1100,7 @@ def open_output(
Find and open an output file with the given tag, in write mode.
If final_name is True then they will be opened using their final
target output name. Otherwise we will prepend "inprogress_" to their
target output name. Otherwise we will prepend `inprogress_` to their
file name. This means we know that if the final file exists then it
is completed.
@@ -1050,19 +1111,22 @@
Parameters
----------
tag: str
Tag as listed in self.outputs
wrapper: bool
Default=False. Whether to return a wrapped file
Whether to return an underlying file object (False) or a data type instance (True)
final_name: bool
Default=False. Whether to save to the final name.
**kwargs:
Extra args are passed on to the file's class constructor.
Returns
-------
obj: file or object
The opened file or object
"""
path = self.get_output(tag, final_name=final_name)
output_class = self.get_output_type(tag)
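
A matching sketch for outputs (tag hypothetical, and assuming an HDF5 output type); the file is created under its inprogress_ name and renamed once the stage completes:

# Within a stage's run() method:
f = self.open_output("tomography_catalog")  # opens inprogress_... in write mode
g = f.create_group("tomography")            # h5py.File methods are available
g.create_dataset("bin", data=[0, 1, 1, 0])
f.close()                                   # moved to the final name at stage end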
@@ -1107,33 +1171,63 @@ def open_output(
@classmethod
def inputs_(cls):
"""
Return the dict of inputs
Return the dict mapping input tags to file names.
Returns
-------
in_dict : dict[str, str]
"""
return cls.inputs # pylint: disable=no-member

@classmethod
def outputs_(cls):
"""
Return the dict of inputs
Return the dict mapping output tags to file names.
Returns
-------
out_dict : dict[str, str]
"""
return cls.outputs # pylint: disable=no-member

@classmethod
def output_tags(cls):
"""
Return the list of output tags required by this stage
Return the list of output tags required by this stage.
Returns
-------
out_tags : list[str]
The list of output tags
"""
return [tag for tag, _ in cls.outputs_()]

@classmethod
def input_tags(cls):
"""
Return the list of input tags required by this stage
Return the list of input tags required by this stage.
Returns
-------
in_tags : list[str]
The list of input tags
"""
return [tag for tag, _ in cls.inputs_()]

def get_input_type(self, tag):
"""Return the file type class of an input file with the given tag."""
"""
Return the file type class of an input file with the given tag.
Parameters
----------
tag : str
The tag of the input file
Returns
-------
ftype : FileType
The file type class
"""
tag = self.get_aliased_tag(tag)
for t, dt in self.inputs_():
t = self.get_aliased_tag(t)
@@ -1142,7 +1236,19 @@ def get_input_type(self, tag):
raise ValueError(f"Tag {tag} is not a known input") # pragma: no cover

def get_output_type(self, tag):
"""Return the file type class of an output file with the given tag."""
"""
Return the file type class of an output file with the given tag.
Parameters
----------
tag : str
The tag of the output file
Returns
-------
ftype : FileType
The file type class
"""
tag = self.get_aliased_tag(tag)
for t, dt in self.outputs_():
t = self.get_aliased_tag(t)
@@ -1162,8 +1268,12 @@ def instance_name(self):
@property
def config(self):
"""
Returns the configuration dictionary for this stage, aggregating command
The configuration dictionary for this stage, aggregating command
line options and optional configuration file.
Options specified in the subclass variable `config_options` are
read from the configuration file, command line, or `make_stage` choices,
and stored in this dictionary.
"""
return self._configs
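
A sketch of how declared options surface here (option names hypothetical):

# With config_options = {"chunk_rows": 10000, "zmax": float} on the class,
# values resolved from the config file or command line appear in self.config:
chunk_rows = self.config["chunk_rows"]
zmax = self.config["zmax"]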

@@ -1292,6 +1402,7 @@ def iterate_fits(
Loop through chunks of the input data from a FITS file with the given tag
TODO: add ceci tests of this function
Parameters
----------
tag: str
@@ -1380,6 +1491,59 @@ def iterate_hdf(
data = {col: group[col][start:end] for col in cols}
yield start, end, data
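
A usage sketch with hypothetical tag, group, and column names:

# Each iteration yields (start, end, data), matching the `yield` above.
for start, end, data in self.iterate_hdf(
    "shear_catalog", "shear", ["ra", "dec"], 10000
):
    print(f"rows {start}:{end}", data["ra"].mean())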

def combined_iterators(self, rows, *inputs, parallel=True):
"""
Iterate through multiple files at the same time.
If you have several HDF5 files with columns of the same
length, you can use this method to iterate through them all
at once and combine the data from all of them into a single
dictionary.
Parameters
----------
rows: int
The number of rows to read in each chunk
*inputs: list
A list of (tag, group, cols) triples for each file to read.
In each case tag is the input file name tag, group is the
group within the HDF5 file to read, and cols is a list of
columns to read from that group. Specify multiple triplets
to read from multiple files
parallel: bool
Whether to split up data among processes (parallel=True) or give
all processes all data (parallel=False). Default = True.
Returns
-------
it: iterator
Iterator yielding (int, int, dict) tuples of (start, end, data)
"""
if len(inputs) % 3 != 0:
raise ValueError(
"Arguments to combined_iterators should be in threes: "
"tag, group, cols"
)
n = len(inputs) // 3

iterators = []
for i in range(n):
tag = inputs[3 * i]
section = inputs[3 * i + 1]
cols = inputs[3 * i + 2]
iterators.append(
self.iterate_hdf(tag, section, cols, rows, parallel=parallel)
)

for it in zip(*iterators):
data = {}
for (s, e, d) in it:
data.update(d)
yield s, e, data
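
A usage sketch of the flattened triple convention (tags, groups, and columns hypothetical):

it = self.combined_iterators(
    10000,
    "shear_catalog", "shear", ["ra", "dec", "g1", "g2"],
    "photometry_catalog", "photometry", ["mag_i"],
)
for start, end, data in it:
    # data merges columns from both files for rows start:end
    do_something(data)  # hypothetical per-chunk processing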


################################
# Pipeline-related methods
################################
@@ -1579,3 +1743,52 @@ def generate_cwl(cls, log_dir=None):
# cwl_tool.metadata = cwlgen.Metadata(**metadata)

return cwl_tool


def time_stamp(self, tag):
"""
Print a time stamp, labelled with the process rank and the given tag.

Parameters
----------
tag: str
Additional info to print in the output line.
"""
t = datetime.datetime.now()
print(f"Process {self.rank}: {tag} {t}")
sys.stdout.flush()

def memory_report(self, tag=None):
"""
Print a report about memory currently available
on the node the process is running on.
Parameters
----------
tag: str
Additional info to print in the output line. Default is empty.
"""
import psutil

t = datetime.datetime.now()

# The different types of memory are really fiddly and don't
# correspond to how you usually imagine. The simplest thing
# to report here is just how much memory is left on the machine.
mem = psutil.virtual_memory()
avail = mem.available / 1024**3
total = mem.total / 1024**3

if tag is None:
tag = ""
else:
tag = f" {tag}:"

# This gives you the name of the host. At NERSC that is the node name
host = socket.gethostname()

# Print message
print(
f"{t}: Process {self.rank}:{tag} Remaining memory on {host} {avail:.1f} GB / {total:.1f} GB"
)
sys.stdout.flush()
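
A sketch of using these diagnostics inside a chunked loop (tag and group names hypothetical):

self.time_stamp("starting catalog pass")
for start, end, data in self.iterate_hdf("shear_catalog", "shear", ["ra"], 100000):
    self.memory_report(f"after rows {start}:{end}")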