Skip to content

Commit

Permalink
Merge pull request #345 from tovrstra/format-exceptions
Browse files Browse the repository at this point in the history
Clean up exceptions and split FileFormatError in LoadError and DumpError
  • Loading branch information
tovrstra authored Jun 21, 2024
2 parents 6d50bf0 + cfbbceb commit be722a9
Show file tree
Hide file tree
Showing 22 changed files with 239 additions and 161 deletions.
91 changes: 78 additions & 13 deletions iodata/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@
from typing import Callable, Optional

from .iodata import IOData
from .utils import FileFormatError, LineIterator, PrepareDumpError
from .utils import (
DumpError,
FileFormatError,
LineIterator,
LoadError,
PrepareDumpError,
WriteInputError,
)

__all__ = ["load_one", "load_many", "dump_one", "dump_many", "write_input"]

Expand All @@ -54,7 +61,7 @@ def _select_format_module(filename: str, attrname: str, fmt: Optional[str] = Non
filename
The file to load or dump.
attrname
The required atrtibute of the file format module.
The required attribute of the file format module.
fmt
The name of the file format module to use. When not given, it is guessed
from the filename.
Expand All @@ -63,6 +70,10 @@ def _select_format_module(filename: str, attrname: str, fmt: Optional[str] = Non
-------
The module implementing the required file format.
Raises
------
FileFormatError
When no file format module can be found that has a member named ``attrname``.
"""
basename = os.path.basename(filename)
if fmt is None:
Expand All @@ -73,7 +84,7 @@ def _select_format_module(filename: str, attrname: str, fmt: Optional[str] = Non
return format_module
else:
return FORMAT_MODULES[fmt]
raise ValueError(f"Could not find file format with feature {attrname} for file {filename}")
raise FileFormatError(f"Could not find file format with feature {attrname} for file {filename}")


def _find_input_modules():
Expand Down Expand Up @@ -102,12 +113,17 @@ def _select_input_module(fmt: str) -> ModuleType:
-------
The module implementing the required input format.
Raises
------
FileFormatError
When the format ``fmt`` does not exist.
"""
if fmt in INPUT_MODULES:
if not hasattr(INPUT_MODULES[fmt], "write_input"):
raise ValueError(f"{fmt} input module does not have write_input!")
raise FileFormatError(f"{fmt} input module does not have write_input.")
return INPUT_MODULES[fmt]
raise ValueError(f"Could not find input format {fmt}!")
raise FileFormatError(f"Could not find input format {fmt}.")


def load_one(filename: str, fmt: Optional[str] = None, **kwargs) -> IOData:
Expand Down Expand Up @@ -136,8 +152,12 @@ def load_one(filename: str, fmt: Optional[str] = None, **kwargs) -> IOData:
with LineIterator(filename) as lit:
try:
iodata = IOData(**format_module.load_one(lit, **kwargs))
except LoadError:
raise
except StopIteration:
lit.error("File ended before all data was read.")
except Exception as exc:
raise LoadError(f"{filename}: Uncaught exception while loading file.") from exc
return iodata


Expand Down Expand Up @@ -171,6 +191,10 @@ def load_many(filename: str, fmt: Optional[str] = None, **kwargs) -> Iterator[IO
yield IOData(**data)
except StopIteration:
return
except LoadError:
raise
except Exception as exc:
raise LoadError(f"{filename}: Uncaught exception while loading file.") from exc


def _check_required(iodata: IOData, dump_func: Callable):
Expand Down Expand Up @@ -216,17 +240,33 @@ def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs)
Raises
------
DumpError
When an error is encountered while dumping to a file.
If the output file already existed, it is (partially) overwritten.
PrepareDumpError
When the iodata object is not compatible with the file format,
e.g. due to missing attributes, and not conversion is available or allowed
to make it compatible.
If the output file already existed, it is not overwritten.
"""
format_module = _select_format_module(filename, "dump_one", fmt)
_check_required(iodata, format_module.dump_one)
if hasattr(format_module, "prepare_dump"):
format_module.prepare_dump(iodata)
try:
_check_required(iodata, format_module.dump_one)
if hasattr(format_module, "prepare_dump"):
format_module.prepare_dump(iodata)
except PrepareDumpError:
raise
except Exception as exc:
raise PrepareDumpError(
f"{filename}: Uncaught exception while preparing for dumping to a file"
) from exc
with open(filename, "w") as f:
format_module.dump_one(f, iodata, **kwargs)
try:
format_module.dump_one(f, iodata, **kwargs)
except DumpError:
raise
except Exception as exc:
raise DumpError(f"{filename}: Uncaught exception while dumping to a file") from exc


def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = None, **kwargs):
Expand All @@ -249,10 +289,16 @@ def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = Non
Raises
------
DumpError
When an error is encountered while dumping to a file.
If the output file already existed, it (partially) overwritten.
PrepareDumpError
When the iodata object is not compatible with the file format,
e.g. due to missing attributes, and not conversion is available or allowed
to make it compatible.
If the output file already existed, it is not overwritten when this error
is raised while processing the first IOData instance in the ``iodatas`` argument.
When the exception is raised in later iterations, any existing file is overwritten.
"""
format_module = _select_format_module(filename, "dump_many", fmt)

Expand All @@ -262,9 +308,18 @@ def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = Non
iter_iodatas = iter(iodatas)
try:
first = next(iter_iodatas)
_check_required(first, format_module.dump_many)
except StopIteration as exc:
raise FileFormatError("dump_many needs at least one iodata object.") from exc
raise DumpError(f"{filename}: dump_many needs at least one iodata object.") from exc
try:
_check_required(first, format_module.dump_many)
if hasattr(format_module, "prepare_dump"):
format_module.prepare_dump(first)
except PrepareDumpError:
raise
except Exception as exc:
raise PrepareDumpError(
f"{filename}: Uncaught exception while preparing for dumping to a file"
) from exc

def checking_iterator():
"""Iterate over all iodata items, not checking the first."""
Expand All @@ -277,7 +332,12 @@ def checking_iterator():
yield other

with open(filename, "w") as f:
format_module.dump_many(f, checking_iterator(), **kwargs)
try:
format_module.dump_many(f, checking_iterator(), **kwargs)
except (PrepareDumpError, DumpError):
raise
except Exception as exc:
raise DumpError(f"{filename}: Uncaught exception while dumping to a file") from exc


def write_input(
Expand Down Expand Up @@ -312,4 +372,9 @@ def write_input(
"""
input_module = _select_input_module(fmt)
with open(filename, "w") as fh:
input_module.write_input(fh, iodata, template, atom_line, **kwargs)
try:
input_module.write_input(fh, iodata, template, atom_line, **kwargs)
except Exception as exc:
raise WriteInputError(
f"{filename}: Uncaught exception while writing an input file"
) from exc
6 changes: 3 additions & 3 deletions iodata/formats/fchk.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ..docstrings import document_dump_one, document_load_many, document_load_one
from ..iodata import IOData
from ..orbitals import MolecularOrbitals
from ..utils import LineIterator, PrepareDumpError, amu
from ..utils import DumpError, LineIterator, PrepareDumpError, amu

__all__ = []

Expand Down Expand Up @@ -221,7 +221,7 @@ def load_one(lit: LineIterator) -> dict:
if nalpha < 0 or nbeta < 0 or nalpha + nbeta <= 0:
lit.error("The number of electrons is not positive.")
if nalpha < nbeta:
raise ValueError(f"n_alpha={nalpha} < n_beta={nbeta} is not valid!")
lit.error(f"n_alpha={nalpha} < n_beta={nbeta} is not valid!")

norba = fchk["Alpha Orbital Energies"].shape[0]
mo_coeffs = np.copy(fchk["Alpha MO coefficients"].reshape(norba, nbasis).T)
Expand Down Expand Up @@ -643,7 +643,7 @@ def dump_one(f: TextIO, data: IOData):
elif shell.ncon == 2 and shell.angmoms == [0, 1]:
shell_types.append(-1)
else:
raise ValueError("Cannot identify type of shell!")
raise DumpError("Cannot identify type of shell!")

num_pure_d_shells = sum([1 for st in shell_types if st == 2])
num_pure_f_shells = sum([1 for st in shell_types if st == 3])
Expand Down
2 changes: 1 addition & 1 deletion iodata/formats/gamess.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _read_data(lit: LineIterator) -> tuple[str, str, list[str]]:
# The dat file only contains symmetry-unique atoms, so we would be incapable of
# supporting non-C1 symmetry without significant additional coding.
if symmetry != "C1":
raise NotImplementedError(f"Only C1 symmetry is supported. Got {symmetry}")
lit.error(f"Only C1 symmetry is supported. Got {symmetry}")
symbols = []
line = True
while line != " $END \n":
Expand Down
2 changes: 1 addition & 1 deletion iodata/formats/gaussianinput.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def load_one(lit: LineIterator):
if not contents:
break
if len(contents) != 4:
raise ValueError("No Cartesian Structure is detected")
lit.error("No Cartesian Structure is detected")
numbers.append(sym2num[contents[0]])
coor = list(map(float, contents[1:]))
coordinates.append(coor)
Expand Down
7 changes: 2 additions & 5 deletions iodata/formats/gromacs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,7 @@
def load_one(lit: LineIterator) -> dict:
"""Do not edit this docstring. It will be overwritten."""
while True:
try:
data = _helper_read_frame(lit)
except StopIteration:
break
data = _helper_read_frame(lit)
title = data[0]
time = data[1]
resnums = np.array(data[2])
Expand Down Expand Up @@ -75,7 +72,7 @@ def load_many(lit: LineIterator) -> Iterator[dict]:
try:
while True:
yield load_one(lit)
except OSError:
except StopIteration:
return


Expand Down
Loading

0 comments on commit be722a9

Please sign in to comment.