add fine tuning #128

Merged
merged 12 commits on May 29, 2024
44 changes: 40 additions & 4 deletions aiida_mlip/calculations/train.py
@@ -6,7 +6,7 @@
import aiida.common.folders
from aiida.engine import CalcJob, CalcJobProcessSpec
import aiida.engine.processes
from aiida.orm import Dict, FolderData, SinglefileData
from aiida.orm import Bool, Dict, FolderData, SinglefileData

from aiida_mlip.data.config import JanusConfigfile
from aiida_mlip.data.model import ModelData
@@ -47,6 +47,15 @@ def validate_inputs(
# Check if the keys actually correspond to a path
if not ((Path(config_file.as_dictionary[key])).resolve()).exists():
raise InputValidationError(f"Path given for {key} does not exist")
# Check if fine-tuning is enabled and validate accordingly
if (
inputs["fine_tune"]
and "foundation_model" not in config_file
and "foundation_model" not in inputs
):
raise InputValidationError(
"Undefined Model to fine-tune in inputs or config file"
)


class Train(CalcJob): # numpydoc ignore=PR01
@@ -89,6 +98,21 @@ def define(cls, spec: CalcJobProcessSpec) -> None:
required=True,
help="Config file with parameters for training",
)

spec.input(
"fine_tune",
valid_type=Bool,
required=False,
default=lambda: Bool(False),
help="Whether fine-tuning a model",
)
spec.input(
"foundation_model",
valid_type=ModelData,
required=False,
help="Model to fine-tune",
)

spec.input(
"metadata.options.output_filename",
valid_type=str,
@@ -148,6 +172,13 @@ def prepare_for_submission(

# Update the config file with absolute paths
config_parse = config_parse.replace(mlip_dict[file], str(abs_path))

# Add foundation_model to the config file if fine-tuning is enabled
if self.inputs.fine_tune and "foundation_model" in self.inputs:
model_data = self.inputs.foundation_model
foundation_model_path = model_data.filepath
config_parse += f"\nfoundation_model: {foundation_model_path}"

# Copy config file content inside the folder where the calculation is run
config_copy = "mlip_train.yml"
with folder.open(config_copy, "w", encoding="utf-8") as configfile:
@@ -158,11 +189,16 @@
# Initialize cmdline_params with train command
codeinfo.cmdline_params = ["train"]
# Create the rest of the command line
cmd_line = {}
cmd_line["mlip-config"] = config_copy
cmd_line = {"mlip-config": config_copy}
if self.inputs.fine_tune:
cmd_line["fine-tune"] = None

# Add cmd line params to codeinfo
for flag, value in cmd_line.items():
codeinfo.cmdline_params += [f"--{flag}", str(value)]
if value is None:
codeinfo.cmdline_params += [f"--{flag}"]
else:
codeinfo.cmdline_params += [f"--{flag}", str(value)]

# Node where the code is saved
codeinfo.code_uuid = self.inputs.code.uuid
27 changes: 27 additions & 0 deletions docs/source/user_guide/training.rst
@@ -76,6 +76,33 @@ while the other parameters are optional. Here is an example (can be found in the
keep_isolated_atoms: True
save_cpu: True

It is also possible to fine-tune models using the same type of `CalcJob`.
In that case, two additional inputs must be provided: `fine_tune` and `foundation_model`.


.. code-block:: python

    inputs = {
        "code": InstalledCode,
        "mlip_config": JanusConfigfile,
        "metadata": {"options": {"output_filename": "aiida-stdout.txt"}},
        "fine_tune": Bool(True),
        "foundation_model": ModelData,
    }

    TrainCalculation = CalculationFactory("janus.train")
    submit(TrainCalculation, inputs)

A model to fine-tune must be provided as an input, either as a `ModelData` node (in which case it must be a model file), or via the `foundation_model` keyword in the config file.
If `fine_tune` is True but no model is given either way, an `InputValidationError` is raised.
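
As a minimal sketch of the first route, a local model file can be wrapped in a `ModelData` node and passed as `foundation_model`. The file path below is illustrative; the `architecture` value mirrors the tests in this pull request:

.. code-block:: python

    from pathlib import Path

    from aiida_mlip.data.model import ModelData

    # Wrap a local MACE model file in a ModelData node (illustrative path)
    foundation_model = ModelData.local_file(
        file=Path("path/to/test.model"), architecture="mace_mp"
    )

    # Add it to the inputs dictionary built above
    inputs["foundation_model"] = foundation_model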

.. note::

    The keywords 'model' and 'foundation_model' refer to two different things.
    'foundation_model' is the path to the model to fine-tune (or a shortcut such as 'small').
    'model' refers to the model type (see the `MACE <https://mace-docs.readthedocs.io/en/latest/>`_ documentation).
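
For the config-file route, a hypothetical excerpt could carry both keywords. The sketch below writes such an excerpt from Python purely for illustration; the keys follow the MACE training options referenced above and the values are placeholders:

.. code-block:: python

    from pathlib import Path

    from aiida_mlip.data.config import JanusConfigfile

    # Hypothetical config excerpt: 'model' selects the model type,
    # 'foundation_model' points to the model file to fine-tune (placeholder path)
    config_path = Path("mlip_fine_tune.yml")
    config_path.write_text(
        "model: MACE\nfoundation_model: /path/to/test.model\n",
        encoding="utf-8",
    )

    config = JanusConfigfile(file=config_path)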


Submission
^^^^^^^^^^

1 change: 0 additions & 1 deletion pyproject.toml
@@ -30,7 +30,6 @@ python = "^3.9"
aiida-core = "^2.5"
ase = "^3.22.1"
voluptuous = "^0.14"
mace-torch = {git = "https://github.com/ACEsuit/mace.git", rev = "develop"}

[tool.poetry.group.dev.dependencies]
coverage = {extras = ["toml"], version = "^7.4.1"}
96 changes: 94 additions & 2 deletions tests/calculations/test_train.py
@@ -3,8 +3,20 @@
import pytest

from aiida.common import InputValidationError, datastructures
from aiida.engine import run
from aiida.orm import Bool
from aiida.plugins import CalculationFactory

from aiida_mlip.data.config import JanusConfigfile
from aiida_mlip.data.model import ModelData

# This is a temporary solution until MACE gets a tag on its current main branch.
try:
from mace.cli.run_train import run as run_train # pylint: disable=unused-import

MACE_IMPORT_ERROR = False
except ImportError:
MACE_IMPORT_ERROR = True


def test_prepare_train(fixture_sandbox, generate_calc_job, janus_code, config_folder):
@@ -31,8 +43,6 @@ def test_prepare_train(fixture_sandbox, generate_calc_job, janus_code, config_fo
"test_compiled.model",
]

print(sorted(calc_info.retrieve_list))
print(retrieve_list)
# Check the attributes of the returned `CalcInfo`
assert fixture_sandbox.get_content_list() == ["mlip_train.yml"]
assert isinstance(calc_info, datastructures.CalcInfo)
@@ -97,3 +107,85 @@ def test_noname(
# Restore config file
with open(config_path, "w", encoding="utf-8") as file:
file.writelines(original_lines)


def test_prepare_tune(fixture_sandbox, generate_calc_job, janus_code, config_folder):
"""Test generating fine tuning calculation job."""

model_file = config_folder / "test.model"
entry_point_name = "janus.train"
config_path = config_folder / "mlip_train.yml"
config = JanusConfigfile(file=config_path)
inputs = {
"metadata": {"options": {"resources": {"num_machines": 1}}},
"code": janus_code,
"mlip_config": config,
"fine_tune": Bool(True),
"foundation_model": ModelData.local_file(
file=model_file, architecture="mace_mp"
),
}

calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs)

cmdline_params = ["train", "--mlip-config", "mlip_train.yml", "--fine-tune"]

retrieve_list = [
calc_info.uuid,
"aiida-stdout.txt",
"logs",
"results",
"checkpoints",
"test.model",
"test_compiled.model",
]

# Check the attributes of the returned `CalcInfo`
assert fixture_sandbox.get_content_list() == ["mlip_train.yml"]
assert isinstance(calc_info, datastructures.CalcInfo)
assert isinstance(calc_info.codes_info[0], datastructures.CodeInfo)
assert sorted(calc_info.retrieve_list) == sorted(retrieve_list)
assert calc_info.codes_info[0].cmdline_params == cmdline_params


def test_finetune_error(fixture_sandbox, generate_calc_job, janus_code, config_folder):
"""Test error if no model is given."""

entry_point_name = "janus.train"
config_path = config_folder / "mlip_train.yml"
config = JanusConfigfile(file=config_path)
inputs = {
"metadata": {"options": {"resources": {"num_machines": 1}}},
"fine_tune": Bool(True),
"code": janus_code,
"mlip_config": config,
}

with pytest.raises(InputValidationError):
generate_calc_job(fixture_sandbox, entry_point_name, inputs)


@pytest.mark.skipif(MACE_IMPORT_ERROR, reason="Requires updated version of MACE")
def test_run_train(janus_code, config_folder):
"""Test running train with fine-tuning calculation"""

model_file = config_folder / "test.model"
config_path = config_folder / "mlip_train.yml"
config = JanusConfigfile(file=config_path)
inputs = {
"metadata": {"options": {"resources": {"num_machines": 1}}},
"fine_tune": Bool(True),
"code": janus_code,
"mlip_config": config,
"foundation_model": ModelData.local_file(
file=model_file, architecture="mace_mp"
),
}

trainfinetuneCalc = CalculationFactory("janus.train")
result = run(trainfinetuneCalc, **inputs)

assert "results_dict" in result
obtained_res = result["results_dict"].get_dict()
assert "logs" in result
assert obtained_res["loss"] == pytest.approx(0.062798671424389)