Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moredocs nico #242

Merged
merged 15 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions apax/md/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def atoms_from_state(self, state, energy, nbr_kwargs):

atoms = Atoms(self.atomic_numbers, positions, momenta=momenta, cell=box)
atoms.cell = atoms.cell.T
atoms.pbc = np.diag(atoms.cell.array) > 1e-7
atoms.pbc = np.diag(atoms.cell.array) > 1e-6
atoms.calc = SinglePointCalculator(atoms, energy=float(energy), forces=forces)
return atoms

Expand All @@ -66,7 +66,7 @@ def __init__(
) -> None:
self.atomic_numbers = system.atomic_numbers
self.box = system.box
self.fractional = np.any(self.box < 1e-6)
self.fractional = np.any(self.box > 1e-6)
self.sampling_rate = sampling_rate
self.traj_path = traj_path
self.db = znh5md.io.DataWriter(self.traj_path)
Expand Down
6 changes: 4 additions & 2 deletions apax/md/nvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,10 @@ def md_setup(model_config: Config, md_config: MDConfig):
r_max = model_config.model.r_max
log.info("initializing model")
if np.all(system.box < 1e-6):
frac_coords = False
displacement_fn, shift_fn = space.free()
else:
frac_coords = True
heights = heights_of_box_sides(system.box)

if np.any(atoms.cell.lengths() / 2 < r_max):
Expand All @@ -356,7 +358,7 @@ def md_setup(model_config: Config, md_config: MDConfig):
"can not calculate the correct neighbors",
)
displacement_fn, shift_fn = space.periodic_general(
system.box, fractional_coordinates=True
system.box, fractional_coordinates=frac_coords
)

builder = ModelBuilder(model_config.model.get_dict())
Expand All @@ -368,7 +370,7 @@ def md_setup(model_config: Config, md_config: MDConfig):
system.box,
r_max,
md_config.dr_threshold,
fractional_coordinates=True,
fractional_coordinates=frac_coords,
format=partition.Sparse,
disable_cell_list=True,
)
Expand Down
15 changes: 8 additions & 7 deletions apax/train/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,13 @@ def load_test_data(
): # TODO double code run.py in progress
log.info("Running Input Pipeline")
os.makedirs(eval_path, exist_ok=True)
if config.data.data_path is not None:

if config.data.test_data_path is not None:
log.info(f"Read test data file {config.data.test_data_path}")
atoms_list = load_data(config.data.test_data_path)
atoms_list = atoms_list[:n_test]

elif config.data.data_path is not None:
log.info(f"Read data file {config.data.data_path}")
atoms_list = load_data(config.data.data_path)

Expand All @@ -54,12 +60,6 @@ def load_test_data(

atoms_list, _ = split_atoms(atoms_list, test_idxs)

elif config.data.test_data_path is not None:
log.info(f"Read test data file {config.data.test_data_path}")
atoms_list, label_dict = load_data(config.data.test_data_path)
atoms_list = atoms_list[:n_test]
for key, val in label_dict.items():
label_dict[key] = val[:n_test]
else:
raise ValueError("input data path/paths not defined")

Expand All @@ -80,6 +80,7 @@ def predict(model, params, Metrics, loss_fn, test_ds, callbacks, is_ensemble=Fal
0, test_ds.n_data, desc="Structure", ncols=100, disable=False, leave=True
)
for batch_idx in range(test_ds.n_data):
callbacks.on_test_batch_begin(batch_idx)
batch = next(batch_test_ds)
batch_start_time = time.time()

Expand Down
17 changes: 17 additions & 0 deletions apax/utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,23 @@ def download_benzene_DFT(data_path):
return new_file_path


def download_etoh_ccsdt(data_path):
url = "http://www.quantum-machine.org/gdml/data/xyz/ethanol_ccsd_t.zip"
file_path = data_path / "ethanol_ccsd_t.zip"

os.makedirs(data_path, exist_ok=True)
urllib.request.urlretrieve(url, file_path)

with zipfile.ZipFile(file_path, "r") as zip_ref:
zip_ref.extractall(data_path)

test_file_path = data_path / "ethanol_ccsd_t-test.xyz"
train_file_path = data_path / "ethanol_ccsd_t-train.xyz"
os.remove(file_path)

return train_file_path, test_file_path


def download_md22_benzene_CCSDT(data_path):
url = "http://www.quantum-machine.org/gdml/data/xyz/benzene_ccsd_t.zip"
file_path = data_path / "benzene_ccsdt.zip"
Expand Down
30 changes: 28 additions & 2 deletions apax/utils/helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import yaml
import csv


def setup_ase():
Expand All @@ -17,8 +18,33 @@ def mod_config(config_path, updated_config):
config_dict = yaml.safe_load(stream)

for key, new_value in updated_config.items():
if isinstance(config_dict[key], dict):
config_dict[key].update(new_value)
if key in config_dict.keys():
if isinstance(config_dict[key], dict):
config_dict[key].update(new_value)
else:
config_dict[key] = new_value
else:
config_dict[key] = new_value
return config_dict


def load_csv_metrics(path):
data_dict = {}

with open(path, "r") as file:
reader = csv.reader(file)

# Extract the headers (keys) from the first row
headers = next(reader)

# Initialize empty lists for each key
for header in headers:
data_dict[header] = []

# Read the rest of the rows and append values to the corresponding key
for row in reader:
for idx, value in enumerate(row):
key = headers[idx]
data_dict[key].append(float(value))

return data_dict
Loading
Loading