Skip to content

Commit

Permalink
Merge pull request #258 from apax-hub/fix-units
Browse files Browse the repository at this point in the history
unit fix
  • Loading branch information
Tetracarbonylnickel authored Apr 4, 2024
2 parents 8f6faa7 + baa6db0 commit f3efd67
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 18 deletions.
2 changes: 2 additions & 0 deletions apax/bal/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def kernel_selection(
bs=processing_batch_size,
n_epochs=1,
ignore_labels=True,
pos_unit=config.data.pos_unit,
energy_unit=config.data.energy_unit,
)

_, init_box = dataset.init_input()
Expand Down
30 changes: 17 additions & 13 deletions apax/data/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import tensorflow as tf

from apax.data.preprocessing import compute_nl, prefetch_to_single_device
from apax.utils.convert import atoms_to_inputs, atoms_to_labels
from apax.utils.convert import atoms_to_inputs, atoms_to_labels, unit_dict

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -44,33 +44,37 @@ def __init__(
n_epochs,
buffer_size=1000,
n_jit_steps=1,
pos_unit: str = "Ang",
energy_unit: str = "eV",
pre_shuffle=False,
ignore_labels=False,
cache_path=".",
) -> None:
if pre_shuffle:
shuffle(atoms)
self.sample_atoms = atoms[0]
self.inputs = atoms_to_inputs(atoms)

self.n_epochs = n_epochs
self.cutoff = cutoff
self.n_jit_steps = n_jit_steps
self.buffer_size = buffer_size
self.n_data = len(atoms)
self.batch_size = self.validate_batch_size(bs)
self.pos_unit = pos_unit

max_atoms, max_nbrs = find_largest_system(self.inputs, cutoff)
if pre_shuffle:
shuffle(atoms)
self.sample_atoms = atoms[0]
self.inputs = atoms_to_inputs(atoms, self.pos_unit)

max_atoms, max_nbrs = find_largest_system(self.inputs, self.cutoff)
self.max_atoms = max_atoms
self.max_nbrs = max_nbrs

if atoms[0].calc and not ignore_labels:
self.labels = atoms_to_labels(atoms)
self.labels = atoms_to_labels(atoms, self.pos_unit, energy_unit)
else:
self.labels = None

self.n_data = len(atoms)
self.count = 0
self.cutoff = cutoff
self.buffer = deque()
self.batch_size = self.validate_batch_size(bs)
self.n_jit_steps = n_jit_steps
self.file = Path(cache_path) / str(uuid.uuid4())

self.enqueue(min(self.buffer_size, self.n_data))
Expand Down Expand Up @@ -164,8 +168,8 @@ def make_signature(self) -> tf.TensorSpec:

def init_input(self) -> Dict[str, np.ndarray]:
"""Returns first batch of inputs and labels to init the model."""
positions = self.sample_atoms.positions
box = self.sample_atoms.cell.array
positions = self.sample_atoms.positions * unit_dict[self.pos_unit]
box = self.sample_atoms.cell.array * unit_dict[self.pos_unit]
idx, offsets = compute_nl(positions, box, self.cutoff)
inputs = (
positions,
Expand Down
6 changes: 5 additions & 1 deletion apax/train/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,11 @@ def eval_model(config_path, n_test=-1, log_file="eval.log", log_level="error"):

atoms_list = load_test_data(config, model_version_path, eval_path, n_test)
test_ds = OTFInMemoryDataset(
atoms_list, config.model.r_max, config.data.valid_batch_size
atoms_list,
config.model.r_max,
config.data.valid_batch_size,
pos_unit=config.data.pos_unit,
energy_unit=config.data.energy_unit,
)

_, init_box = test_ds.init_input()
Expand Down
4 changes: 4 additions & 0 deletions apax/train/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def initialize_datasets(config: Config):
config.n_epochs,
config.data.shuffle_buffer_size,
config.n_jitted_steps,
config.data.pos_unit,
config.data.energy_unit,
pre_shuffle=True,
cache_path=config.data.model_version_path,
)
Expand All @@ -71,6 +73,8 @@ def initialize_datasets(config: Config):
config.model.r_max,
config.data.valid_batch_size,
config.n_epochs,
pos_unit=config.data.pos_unit,
energy_unit=config.data.energy_unit,
cache_path=config.data.model_version_path,
)
ds_stats = compute_scale_shift_parameters(
Expand Down
9 changes: 5 additions & 4 deletions tests/regression_tests/apax_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,14 @@ metrics:
# - mse

loss:
- loss_type: structures
name: energy
- name: energy
atoms_exponent: 2
weight: 1.0
- loss_type: structures
name: forces
- name: forces
atoms_exponent: 1
weight: 8.0
- loss_type: cosine_sim
atoms_exponent: 1
name: forces
weight: 0.1
# - loss_type: structures
Expand Down

0 comments on commit f3efd67

Please sign in to comment.