Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update dev branch with new CI #230

Merged
merged 5 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 0 additions & 48 deletions .github/workflows/linting.yaml

This file was deleted.

12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,25 @@ fail_fast: true

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
rev: v4.5.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace

- repo: https://github.com/psf/black
rev: 22.10.0
rev: 24.1.1
hooks:
- id: black

- repo: https://github.com/timothycrosley/isort
rev: 5.10.1
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]

- repo: https://gitlab.com/pycqa/flake8
rev: 5.0.1
- repo: https://github.com/pycqa/flake8
rev: 7.0.0
hooks:
- id: flake8
additional_dependencies: [ flake8-isort ]
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ build:

# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/source/conf.py
configuration: docs/source/conf.py
2 changes: 1 addition & 1 deletion apax/cli/templates/train_config_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ data:
#train_data_path: <PATH>
#val_data_path: <PATH>
#test_data_path: <PATH>

n_train: 1000
n_valid: 100

Expand Down
14 changes: 6 additions & 8 deletions apax/data/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,12 @@ def dataset_from_dicts(
for key, val in labels["fixed"].items():
labels["fixed"][key] = tf.constant(val)

ds = tf.data.Dataset.from_tensor_slices(
(
inputs["ragged"],
inputs["fixed"],
labels["ragged"],
labels["fixed"],
)
)
ds = tf.data.Dataset.from_tensor_slices((
inputs["ragged"],
inputs["fixed"],
labels["ragged"],
labels["fixed"],
))
return ds


Expand Down
14 changes: 4 additions & 10 deletions apax/train/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,7 @@ def create_single_train_state(params):
return state

if n_models > 1:
train_state_fn = jax.vmap(
create_single_train_state,
axis_name="ensemble"
)
train_state_fn = jax.vmap(create_single_train_state, axis_name="ensemble")
else:
train_state_fn = create_single_train_state

Expand Down Expand Up @@ -129,9 +126,7 @@ def load_params(model_version_path: Path, best=True) -> FrozenDict:
try:
# keep try except block for zntrack load from rev
raw_restored = checkpoints.restore_checkpoint(
model_version_path,
target=None,
step=None
model_version_path, target=None, step=None
)
except FileNotFoundError:
print(f"No checkpoint found at {model_version_path}")
Expand All @@ -143,8 +138,7 @@ def load_params(model_version_path: Path, best=True) -> FrozenDict:


def restore_single_parameters(model_dir: Path) -> Tuple[Config, FrozenDict]:
"""Load the config and parameters of a single model
"""
"""Load the config and parameters of a single model"""
model_dir = Path(model_dir)
model_config = parse_config(model_dir / "config.yaml")

Expand Down Expand Up @@ -200,6 +194,6 @@ def canonicalize_energy_grad_model_parameters(params):

first_level = param_dict["params"]
if "energy_model" not in first_level.keys():
params = {"params": {"energy_model" : first_level}}
params = {"params": {"energy_model": first_level}}
params = freeze(params)
return params
10 changes: 4 additions & 6 deletions apax/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,10 @@ def fit(
epoch_loss["val_loss"] /= val_steps_per_epoch
epoch_loss["val_loss"] = float(epoch_loss["val_loss"])

epoch_metrics.update(
{
f"val_{key}": float(val)
for key, val in val_batch_metrics.compute().items()
}
)
epoch_metrics.update({
f"val_{key}": float(val)
for key, val in val_batch_metrics.compute().items()
})

epoch_metrics.update({**epoch_loss})

Expand Down
2 changes: 1 addition & 1 deletion apax/utils/jax_md_reduced/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ We would like to thank the developers of `jax_md` for the work on this great pac
volume = {33},
year = {2020}
}
```
```
2 changes: 1 addition & 1 deletion docs/source/_tutorials/md_with_ase.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ An ASE calculator of a trained model can be instantiated as follows

CODE

Please refer to the ASE documentation LINK to see how to use ASE calculators.
Please refer to the ASE documentation LINK to see how to use ASE calculators.
5 changes: 1 addition & 4 deletions docs/source/_tutorials/molecular_dynamics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ Congratulations, you have calculated the first observable from a trajectory gene

## Custom Simulation Loops

More complex simulation loops are relatively easy to build yourself in JaxMD (see their colab notebooks for examples).
More complex simulation loops are relatively easy to build yourself in JaxMD (see their colab notebooks for examples).
Trained apax models can of course be used as `energy_fn` in such custom simulations.
If you have a suggestion for adding some MD feature or thermostat to the core of `apax`, feel free to open up an issue on Github LINK.



2 changes: 1 addition & 1 deletion docs/source/_tutorials/training_a_model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,4 @@ We provide a separate command for test set evaluation:

TODO pretty print results to the terminal

Congratulations, you have successfully trained and evaluated your first apax model!
Congratulations, you have successfully trained and evaluated your first apax model!
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@

html_theme = "furo"

html_theme_options = {}
html_theme_options = {}
2 changes: 1 addition & 1 deletion docs/source/getting_started/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ Getting Started
.. toctree::
:maxdepth: 2

install
install
2 changes: 1 addition & 1 deletion docs/source/getting_started/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,4 @@ If you want to enable GPU support, please overwrite the jaxlib version:
See the `Jax installation instructions <https://github.com/google/jax#installation>`_ for more details.


.. _Poetry: https://python-poetry.org/
.. _Poetry: https://python-poetry.org/
2 changes: 1 addition & 1 deletion docs/source/modules/md.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ Molecular Dynamics
:members:

.. automodule:: apax.md.nvt
:members:
:members:
2 changes: 1 addition & 1 deletion docs/source/modules/optimizer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ Optimizers
==========

.. automodule:: apax.optimizer.get_optimizer
:members:
:members:
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,3 @@ directory = "coverage_html_report"

[tool.coverage.report]
show_missing = true

12 changes: 5 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,11 @@ def modify_xyz_file(file_path, target_string, replacement_string):

@pytest.fixture()
def get_sample_input():
positions = np.array(
[
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
]
)
positions = np.array([
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
])
atomic_numbers = np.array([1, 1, 8])
box = np.diag(np.zeros(3))
offsets = np.full([3, 3], 0)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/bal/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ metrics:

loss:
- name: energy
- name: forces
- name: forces
2 changes: 1 addition & 1 deletion tests/integration_tests/md/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ metrics:

loss:
- name: energy
- name: forces
- name: forces
2 changes: 1 addition & 1 deletion tests/integration_tests/md/md_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ duration: 0.25
n_inner: 1
sampling_rate: 1
checkpoint_interval: 2
restart: True
restart: True
12 changes: 5 additions & 7 deletions tests/integration_tests/md/test_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,11 @@ def test_ase_calc(get_tmp_path):
model_config.dump_config(model_config.data.model_version_path)

cell_size = 10.0
positions = np.array(
[
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
]
)
positions = np.array([
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
])
atomic_numbers = np.array([1, 1, 8])
box = np.diag([cell_size] * 3)
offsets = jnp.full([3, 3], 0)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/transfer_learning/config_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ optimizer:
emb_lr: 0.001
nn_lr: 0.001
scale_lr: 0.0001
shift_lr: 0.001
shift_lr: 0.001
2 changes: 1 addition & 1 deletion tests/integration_tests/transfer_learning/config_ft.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ optimizer:
shift_lr: 0.001

checkpoints:
base_model_checkpoint: null
base_model_checkpoint: null
Loading
Loading