Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up exceptions and split FileFormatError in LoadError and DumpError #345

Merged
merged 3 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 78 additions & 13 deletions iodata/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@
from typing import Callable, Optional

from .iodata import IOData
from .utils import FileFormatError, LineIterator, PrepareDumpError
from .utils import (
DumpError,
FileFormatError,
LineIterator,
LoadError,
PrepareDumpError,
WriteInputError,
)

__all__ = ["load_one", "load_many", "dump_one", "dump_many", "write_input"]

Expand All @@ -54,7 +61,7 @@ def _select_format_module(filename: str, attrname: str, fmt: Optional[str] = Non
filename
The file to load or dump.
attrname
tovrstra marked this conversation as resolved.
Show resolved Hide resolved
The required atrtibute of the file format module.
The required attribute of the file format module.
fmt
The name of the file format module to use. When not given, it is guessed
from the filename.
Expand All @@ -63,6 +70,10 @@ def _select_format_module(filename: str, attrname: str, fmt: Optional[str] = Non
-------
The module implementing the required file format.

Raises
------
FileFormatError
When no file format module can be found that has a member named ``attrname``.
"""
basename = os.path.basename(filename)
if fmt is None:
Expand All @@ -73,7 +84,7 @@ def _select_format_module(filename: str, attrname: str, fmt: Optional[str] = Non
return format_module
else:
return FORMAT_MODULES[fmt]
raise ValueError(f"Could not find file format with feature {attrname} for file {filename}")
raise FileFormatError(f"Could not find file format with feature {attrname} for file {filename}")


def _find_input_modules():
Expand Down Expand Up @@ -102,12 +113,17 @@ def _select_input_module(fmt: str) -> ModuleType:
-------
The module implementing the required input format.


Raises
------
FileFormatError
When the format ``fmt`` does not exist.
"""
if fmt in INPUT_MODULES:
if not hasattr(INPUT_MODULES[fmt], "write_input"):
raise ValueError(f"{fmt} input module does not have write_input!")
raise FileFormatError(f"{fmt} input module does not have write_input.")
return INPUT_MODULES[fmt]
raise ValueError(f"Could not find input format {fmt}!")
raise FileFormatError(f"Could not find input format {fmt}.")


def load_one(filename: str, fmt: Optional[str] = None, **kwargs) -> IOData:
Expand Down Expand Up @@ -136,8 +152,12 @@ def load_one(filename: str, fmt: Optional[str] = None, **kwargs) -> IOData:
with LineIterator(filename) as lit:
try:
iodata = IOData(**format_module.load_one(lit, **kwargs))
except LoadError:
raise
except StopIteration:
lit.error("File ended before all data was read.")
except Exception as exc:
raise LoadError(f"{filename}: Uncaught exception while loading file.") from exc
return iodata


Expand Down Expand Up @@ -171,6 +191,10 @@ def load_many(filename: str, fmt: Optional[str] = None, **kwargs) -> Iterator[IO
yield IOData(**data)
except StopIteration:
return
except LoadError:
raise
except Exception as exc:
raise LoadError(f"{filename}: Uncaught exception while loading file.") from exc


def _check_required(iodata: IOData, dump_func: Callable):
Expand Down Expand Up @@ -216,17 +240,33 @@ def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs)

Raises
------
DumpError
When an error is encountered while dumping to a file.
If the output file already existed, it (partially) overwritten.
PrepareDumpError
When the iodata object is not compatible with the file format,
e.g. due to missing attributes, and not conversion is available or allowed
to make it compatible.
tovrstra marked this conversation as resolved.
Show resolved Hide resolved
If the output file already existed, it is not overwritten.
"""
format_module = _select_format_module(filename, "dump_one", fmt)
_check_required(iodata, format_module.dump_one)
if hasattr(format_module, "prepare_dump"):
format_module.prepare_dump(iodata)
try:
_check_required(iodata, format_module.dump_one)
if hasattr(format_module, "prepare_dump"):
format_module.prepare_dump(iodata)
except PrepareDumpError:
raise
except Exception as exc:
raise PrepareDumpError(
f"{filename}: Uncaught exception while preparing for dumping to a file"
) from exc
with open(filename, "w") as f:
format_module.dump_one(f, iodata, **kwargs)
try:
format_module.dump_one(f, iodata, **kwargs)
except DumpError:
raise
except Exception as exc:
raise DumpError(f"{filename}: Uncaught exception while dumping to a file") from exc


def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = None, **kwargs):
Expand All @@ -249,10 +289,16 @@ def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = Non

Raises
------
DumpError
When an error is encountered while dumping to a file.
If the output file already existed, it (partially) overwritten.
PrepareDumpError
When the iodata object is not compatible with the file format,
e.g. due to missing attributes, and not conversion is available or allowed
to make it compatible.
If the output file already existed, it is not overwritten when this error
is raised while processing the first IOData instance in the ``iodatas`` argument.
When the exception is raised in later iterations, any existing file is overwritten.
"""
format_module = _select_format_module(filename, "dump_many", fmt)

Expand All @@ -262,9 +308,18 @@ def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = Non
iter_iodatas = iter(iodatas)
try:
first = next(iter_iodatas)
_check_required(first, format_module.dump_many)
except StopIteration as exc:
raise FileFormatError("dump_many needs at least one iodata object.") from exc
raise DumpError(f"{filename}: dump_many needs at least one iodata object.") from exc
try:
_check_required(first, format_module.dump_many)
if hasattr(format_module, "prepare_dump"):
format_module.prepare_dump(first)
except PrepareDumpError:
raise
except Exception as exc:
raise PrepareDumpError(
f"{filename}: Uncaught exception while preparing for dumping to a file"
) from exc

def checking_iterator():
"""Iterate over all iodata items, not checking the first."""
Expand All @@ -277,7 +332,12 @@ def checking_iterator():
yield other

with open(filename, "w") as f:
format_module.dump_many(f, checking_iterator(), **kwargs)
try:
format_module.dump_many(f, checking_iterator(), **kwargs)
except (PrepareDumpError, DumpError):
raise
except Exception as exc:
raise DumpError(f"{filename}: Uncaught exception while dumping to a file") from exc


def write_input(
Expand Down Expand Up @@ -312,4 +372,9 @@ def write_input(
"""
input_module = _select_input_module(fmt)
with open(filename, "w") as fh:
input_module.write_input(fh, iodata, template, atom_line, **kwargs)
try:
input_module.write_input(fh, iodata, template, atom_line, **kwargs)
except Exception as exc:
raise WriteInputError(
f"{filename}: Uncaught exception while writing an input file"
) from exc
6 changes: 3 additions & 3 deletions iodata/formats/fchk.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ..docstrings import document_dump_one, document_load_many, document_load_one
from ..iodata import IOData
from ..orbitals import MolecularOrbitals
from ..utils import LineIterator, PrepareDumpError, amu
from ..utils import DumpError, LineIterator, PrepareDumpError, amu

__all__ = []

Expand Down Expand Up @@ -221,7 +221,7 @@ def load_one(lit: LineIterator) -> dict:
if nalpha < 0 or nbeta < 0 or nalpha + nbeta <= 0:
lit.error("The number of electrons is not positive.")
if nalpha < nbeta:
raise ValueError(f"n_alpha={nalpha} < n_beta={nbeta} is not valid!")
lit.error(f"n_alpha={nalpha} < n_beta={nbeta} is not valid!")

norba = fchk["Alpha Orbital Energies"].shape[0]
mo_coeffs = np.copy(fchk["Alpha MO coefficients"].reshape(norba, nbasis).T)
Expand Down Expand Up @@ -643,7 +643,7 @@ def dump_one(f: TextIO, data: IOData):
elif shell.ncon == 2 and shell.angmoms == [0, 1]:
shell_types.append(-1)
else:
raise ValueError("Cannot identify type of shell!")
raise DumpError("Cannot identify type of shell!")

num_pure_d_shells = sum([1 for st in shell_types if st == 2])
num_pure_f_shells = sum([1 for st in shell_types if st == 3])
Expand Down
2 changes: 1 addition & 1 deletion iodata/formats/gamess.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _read_data(lit: LineIterator) -> tuple[str, str, list[str]]:
# The dat file only contains symmetry-unique atoms, so we would be incapable of
# supporting non-C1 symmetry without significant additional coding.
if symmetry != "C1":
raise NotImplementedError(f"Only C1 symmetry is supported. Got {symmetry}")
lit.error(f"Only C1 symmetry is supported. Got {symmetry}")
symbols = []
line = True
while line != " $END \n":
Expand Down
2 changes: 1 addition & 1 deletion iodata/formats/gaussianinput.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def load_one(lit: LineIterator):
if not contents:
break
if len(contents) != 4:
raise ValueError("No Cartesian Structure is detected")
lit.error("No Cartesian Structure is detected")
numbers.append(sym2num[contents[0]])
coor = list(map(float, contents[1:]))
coordinates.append(coor)
Expand Down
7 changes: 2 additions & 5 deletions iodata/formats/gromacs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,7 @@
def load_one(lit: LineIterator) -> dict:
"""Do not edit this docstring. It will be overwritten."""
while True:
try:
data = _helper_read_frame(lit)
except StopIteration:
break
data = _helper_read_frame(lit)
title = data[0]
time = data[1]
resnums = np.array(data[2])
Expand Down Expand Up @@ -75,7 +72,7 @@ def load_many(lit: LineIterator) -> Iterator[dict]:
try:
while True:
yield load_one(lit)
except OSError:
except StopIteration:
return


Expand Down
Loading