Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 27, 2024
1 parent d824ddc commit 042a768
Showing 1 changed file with 41 additions and 41 deletions.
82 changes: 41 additions & 41 deletions sevenn/train/dataload.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _graph_build_matscipy(cutoff: float, pbc, cell, pos):
cell[2, :] = max_positions * 5 * cutoff * identity[2, :]
# it does not have self-interaction
edge_src, edge_dst, edge_vec, shifts = neighbour_list(
quantities="ijDS",
quantities='ijDS',
pbc=pbc,
cell=cell,
positions=pos,
Expand All @@ -63,7 +63,7 @@ def _graph_build_matscipy(cutoff: float, pbc, cell, pos):
def _graph_build_ase(cutoff: float, pbc, cell, pos):
# building neighbor list
edge_src, edge_dst, edge_vec, shifts = primitive_neighbor_list(
"ijDS", pbc, cell, pos, cutoff, self_interaction=True
'ijDS', pbc, cell, pos, cutoff, self_interaction=True
)

is_zero_idx = np.all(edge_vec == 0, axis=1)
Expand Down Expand Up @@ -155,9 +155,9 @@ def atoms_to_graph(
Requires grad is handled by 'dataset' not here.
"""
if not y_from_calc:
y_energy = atoms.info["y_energy"]
y_force = atoms.arrays["y_force"]
y_stress = atoms.info.get("y_stress", np.full((6,), np.nan))
y_energy = atoms.info['y_energy']
y_force = atoms.arrays['y_force']
y_stress = atoms.info.get('y_stress', np.full((6,), np.nan))
if y_stress.shape == (3, 3):
y_stress = np.array(
[
Expand All @@ -182,7 +182,7 @@ def atoms_to_graph(
y_stress = np.array(y_stress[[0, 1, 2, 5, 3, 4]])
except RuntimeError:
y_stress = np.full((6,), np.nan)
assert y_stress.shape == (6,), "If you see this, please report to the maintainer"
assert y_stress.shape == (6,), 'If you see this, please report to the maintainer'

pos = atoms.get_positions()
cell = np.array(atoms.get_cell())
Expand Down Expand Up @@ -218,12 +218,12 @@ def atoms_to_graph(
info = copy.deepcopy(atoms.info)
# save only metadata
# TODO: is it really necessary?
if "y_energy" in info:
del info["y_energy"]
if "y_force" in info:
del info["y_force"]
if "y_stress" in info:
del info["y_stress"]
if 'y_energy' in info:
del info['y_energy']
if 'y_force' in info:
del info['y_force']
if 'y_stress' in info:
del info['y_stress']
data[KEY.INFO] = info

else:
Expand Down Expand Up @@ -265,7 +265,7 @@ def graph_build(
pool.join()
else:
graph_list = [
atoms_to_graph(*input_) for input_ in tqdm(inputs, desc="graph_build (1)")
atoms_to_graph(*input_) for input_ in tqdm(inputs, desc='graph_build (1)')
]

graph_list = [AtomGraphData.from_numpy_dict(g) for g in graph_list]
Expand Down Expand Up @@ -306,25 +306,25 @@ def _set_atoms_y(
for atoms in atoms_list:
# access energy
if energy_key is not None:
atoms.info["y_energy"] = atoms.info[energy_key]
atoms.info['y_energy'] = atoms.info[energy_key]
del atoms.info[energy_key]
else:
try:
atoms.info["y_energy"] = atoms.get_potential_energy(
atoms.info['y_energy'] = atoms.get_potential_energy(
force_consistent=True
)
except NotImplementedError:
atoms.info["y_energy"] = atoms.get_potential_energy()
atoms.info['y_energy'] = atoms.get_potential_energy()
# access force
if force_key is not None:
atoms.arrays["y_force"] = atoms.arrays[force_key]
atoms.arrays['y_force'] = atoms.arrays[force_key]
del atoms.arrays[force_key]
else:
atoms.arrays["y_force"] = atoms.get_forces(apply_constraint=False)
atoms.arrays['y_force'] = atoms.get_forces(apply_constraint=False)
# access stress
if stress_key is not None:
y_stress = -1 * atoms.info[stress_key]
atoms.info["y_stress"] = np.array(y_stress[[0, 1, 2, 5, 3, 4]])
atoms.info['y_stress'] = np.array(y_stress[[0, 1, 2, 5, 3, 4]])
del atoms.info[stress_key]
else:
try:
Expand All @@ -333,9 +333,9 @@ def _set_atoms_y(
# (ASE automatically converts vasp kB to eV/A^3)
# So we restore it
y_stress = -1 * atoms.get_stress()
atoms.info["y_stress"] = np.array(y_stress[[0, 1, 2, 5, 3, 4]])
atoms.info['y_stress'] = np.array(y_stress[[0, 1, 2, 5, 3, 4]])
except RuntimeError:
atoms.info["y_stress"] = np.full((6,), np.nan)
atoms.info['y_stress'] = np.full((6,), np.nan)
return atoms_list


Expand All @@ -344,7 +344,7 @@ def ase_reader(
energy_key: Optional[str] = None,
force_key: Optional[str] = None,
stress_key: Optional[str] = None,
index: str = ":",
index: str = ':',
**kwargs,
) -> list[ase.Atoms]:
"""
Expand All @@ -358,7 +358,7 @@ def ase_reader(


# Reader
def structure_list_reader(filename: str, format_outputs="vasp-out"):
def structure_list_reader(filename: str, format_outputs='vasp-out'):
"""
Deprecated
Read from structure_list using braceexpand and ASE
Expand All @@ -377,27 +377,27 @@ def structure_list_reader(filename: str, format_outputs="vasp-out"):

def parse_label(line):
line = line.strip()
if line.startswith("[") is False:
if line.startswith('[') is False:
return False
elif line.endswith("]") is False:
raise ValueError("wrong structure_list title format")
elif line.endswith(']') is False:
raise ValueError('wrong structure_list title format')
return line[1:-1]

def parse_fileline(line):
line = line.strip().split()
if len(line) == 1:
line.append(":")
line.append(':')
elif len(line) != 2:
raise ValueError("wrong structure_list format")
raise ValueError('wrong structure_list format')
return line[0], line[1]

structure_list_file = open(filename, "r")
structure_list_file = open(filename, 'r')
lines = structure_list_file.readlines()

raw_str_dict = {}
label = "Default"
label = 'Default'
for line in lines:
if line.strip() == "":
if line.strip() == '':
continue
tmp_label = parse_label(line)
if tmp_label:
Expand All @@ -408,36 +408,36 @@ def parse_fileline(line):
files_expr, index_expr = parse_fileline(line)
raw_str_dict[label].append((files_expr, index_expr))
else:
raise ValueError("wrong structure_list format")
raise ValueError('wrong structure_list format')
structure_list_file.close()

structures_dict = {}
info_dct = {"data_from": "user_OUTCAR"}
info_dct = {'data_from': 'user_OUTCAR'}
for title, file_lines in raw_str_dict.items():
stct_lists = []
for file_line in file_lines:
files_expr, index_expr = file_line
index = string2index(index_expr)
for expanded_filename in list(braceexpand(files_expr)):
f_stream = open(expanded_filename, "r")
f_stream = open(expanded_filename, 'r')
# generator of all outcar ionic steps
gen_all = outcarchunks(f_stream, ocp)
try: # TODO: index may not slice, it can be integer
it_atoms = islice(gen_all, index.start, index.stop, index.step)
except ValueError:
# TODO: support
# negative index
raise ValueError("Negative index is not supported yet")
raise ValueError('Negative index is not supported yet')

info_dct_f = {
**info_dct,
"file": os.path.abspath(expanded_filename),
'file': os.path.abspath(expanded_filename),
}
for idx, o in enumerate(it_atoms):
try:
istep = index.start + idx * index.step
atoms = o.build()
atoms.info = {**info_dct_f, "ionic_step": istep}
atoms.info = {**info_dct_f, 'ionic_step': istep}
except TypeError: # it is not slice of ionic steps
atoms = o.build()
atoms.info = info_dct_f
Expand All @@ -450,12 +450,12 @@ def parse_fileline(line):
def match_reader(reader_name: str, **kwargs):
reader = None
metadata = {}
if reader_name == "structure_list":
if reader_name == 'structure_list':
reader = partial(structure_list_reader, **kwargs)
metadata.update({"origin": "structure_list"})
metadata.update({'origin': 'structure_list'})
else:
reader = partial(ase_reader, **kwargs)
metadata.update({"origin": "ase_reader"})
metadata.update({'origin': 'ase_reader'})
return reader, metadata


Expand Down Expand Up @@ -486,7 +486,7 @@ def file_to_dataset(
elif isinstance(atoms, dict):
atoms_dct = atoms
else:
raise TypeError("The return of reader is not list or dict")
raise TypeError('The return of reader is not list or dict')

graph_dct = {}
for label, atoms_list in atoms_dct.items():
Expand Down

0 comments on commit 042a768

Please sign in to comment.