Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

142 model download change #150

Merged
merged 20 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions aiida_mlip/calculations/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Base class for features common to most calculations."""

import shutil

from ase.io import read, write

from aiida.common import InputValidationError, datastructures
Expand Down Expand Up @@ -63,10 +65,13 @@ def validate_inputs(
if (
"arch" in inputs
and "model" in inputs
and inputs["arch"].value is not inputs["model"].architecture
and inputs["arch"].value != inputs["model"].architecture
):
inputvalue = inputs["arch"].value
modelvalue = inputs["model"].architecture
raise InputValidationError(
"'arch' in ModelData and in 'arch' input must be the same"
"'arch' in ModelData and in inputs must be the same, "
f"but they are {modelvalue} and {inputvalue}"
)


Expand Down Expand Up @@ -199,8 +204,6 @@ def prepare_for_submission(
An instance of `aiida.common.datastructures.CalcInfo`.
"""

# Create needed inputs

if "struct" in self.inputs:
structure = self.inputs.struct
elif "config" in self.inputs and "struct" in self.inputs.config.as_dictionary:
Expand All @@ -211,8 +214,8 @@ def prepare_for_submission(
# Transform the structure data into an xyz file called input_filename
input_filename = self.inputs.metadata.options.input_filename
atoms = structure.get_ase()
# with folder.open(input_filename, mode="w", encoding='utf8') as file:
write(folder.abspath + "/" + input_filename, images=atoms)
with folder.open(input_filename, mode="w", encoding="utf8") as file:
write(file.name, images=atoms)

log_filename = (self.inputs.log_filename).value
cmd_line = {
Expand All @@ -231,7 +234,7 @@ def prepare_for_submission(
# Define architecture from model if model is given,
# otherwise get architecture from inputs and download default model
self._add_arch_to_cmdline(cmd_line)
self._add_model_to_cmdline(cmd_line)
self._add_model_to_cmdline(cmd_line, folder)

if "config" in self.inputs:
# Add config file to command line
Expand Down Expand Up @@ -290,8 +293,7 @@ def _add_arch_to_cmdline(self, cmd_line: dict) -> dict:
cmd_line["arch"] = architecture

def _add_model_to_cmdline(
self,
cmd_line: dict,
self, cmd_line: dict, folder: aiida.common.folders.Folder
) -> dict:
"""
Find model in inputs or config file and add to command line if needed.
Expand All @@ -301,6 +303,9 @@ def _add_model_to_cmdline(
cmd_line : dict
Dictionary containing the cmd line keys.

folder : ~aiida.common.folders.Folder
An `aiida.common.folders.Folder` to temporarily write files on disk.
federicazanca marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
dict
Expand All @@ -311,6 +316,10 @@ def _add_model_to_cmdline(
# Raise error if model is None (different than model not given as input)
if self.inputs.model is None:
raise ValueError("Model cannot be None")
model_path = self.inputs.model.filepath
if model_path:

with self.inputs.model.open(mode="rb") as source:
with folder.open("mlff.model", mode="wb") as target:
federicazanca marked this conversation as resolved.
Show resolved Hide resolved
shutil.copyfileobj(source, target)

model_path = "mlff.model"
cmd_line.setdefault("calc-kwargs", {})["model"] = model_path
8 changes: 5 additions & 3 deletions aiida_mlip/calculations/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Class for training machine learning models."""

from pathlib import Path
import shutil

from aiida.common import InputValidationError, datastructures
import aiida.common.folders
Expand Down Expand Up @@ -175,9 +176,10 @@ def prepare_for_submission(

# Add foundation_model to the config file if fine-tuning is enabled
if self.inputs.fine_tune and "foundation_model" in self.inputs:
model_data = self.inputs.foundation_model
foundation_model_path = model_data.filepath
config_parse += f"\nfoundation_model: {foundation_model_path}"
with self.inputs.foundation_model.open(mode="rb") as source:
with folder.open("mlff.model", mode="wb") as target:
shutil.copyfileobj(source, target)
config_parse += "foundation_model: mlff.model"

# Copy config file content inside the folder where the calculation is run
config_copy = "mlip_train.yml"
Expand Down
145 changes: 58 additions & 87 deletions aiida_mlip/data/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from pathlib import Path
from typing import Any, Optional, Union
from urllib import request
from urllib.parse import urlparse

from aiida.orm import SinglefileData
from aiida.orm import QueryBuilder, SinglefileData, load_node


class ModelData(SinglefileData):
Expand All @@ -26,17 +25,17 @@ class ModelData(SinglefileData):
----------
architecture : str
Architecture of the mlip model.
filepath : str
Path of the mlip model.
model_hash : str
Hash of the model.

Methods
-------
set_file(file, filename=None, architecture=None, **kwargs)
Set the file for the node.
local_file(file, architecture, filename=None):
from_local(file, architecture, filename=None):
Create a ModelData instance from a local file.
download(url, architecture, filename=None, cache_dir=None, force_download=False)
Download a file from a URL and save it as ModelData.
from_uri(uri, architecture, filename=None, cache_dir=None, keep_file=False)
Download a file from a uri and save it as ModelData.

Other Parameters
----------------
Expand Down Expand Up @@ -69,47 +68,6 @@ def _calculate_hash(file: Union[str, Path]) -> str:
file_hash = sha256.hexdigest()
return file_hash

@classmethod
def _check_existing_file(cls, file: Union[str, Path]) -> Path:
"""
Check if a file already exists and return the path of the existing file.

Parameters
----------
file : Union[str, Path]
Path to the downloaded model file.

Returns
-------
Path
The path of the model file of interest (same as input path if no duplicates
were found).
"""
file_hash = cls._calculate_hash(file)

def is_diff_file(curr_path: Path) -> bool:
"""
Filter to check if two files are different.

Parameters
----------
curr_path : Path
Path to the file to compare with.

Returns
-------
bool
True if the files are different, False otherwise.
"""
return curr_path.is_file() and not curr_path.samefile(file)

file_folder = Path(file).parent
for existing_file in filter(is_diff_file, file_folder.rglob("*")):
if cls._calculate_hash(existing_file) == file_hash:
file.unlink()
return existing_file
return Path(file)

def __init__(
self,
file: Union[str, Path],
Expand All @@ -136,7 +94,6 @@ def __init__(
"""
super().__init__(file, filename, **kwargs)
self.base.attributes.set("architecture", architecture)
self.base.attributes.set("filepath", str(file))

def set_file(
self,
Expand Down Expand Up @@ -164,10 +121,12 @@ def set_file(
"""
super().set_file(file, filename, **kwargs)
self.base.attributes.set("architecture", architecture)
self.base.attributes.set("filepath", str(file))
# here compute hash and set attribute
model_hash = self._calculate_hash(file)
self.base.attributes.set("model_hash", model_hash)

@classmethod
def local_file(
def from_local(
cls,
file: Union[str, Path],
architecture: str,
Expand Down Expand Up @@ -195,31 +154,31 @@ def local_file(

@classmethod
# pylint: disable=too-many-arguments
def download(
def from_uri(
cls,
url: str,
uri: str,
architecture: str,
filename: Optional[str] = None,
filename: Optional[str] = "tmp_file.model",
cache_dir: Optional[Union[str, Path]] = None,
force_download: Optional[bool] = False,
keep_file: Optional[bool] = False,
):
"""
Download a file from a URL and save it as ModelData.
Download a file from a uri and save it as ModelData.

Parameters
----------
url : str
URL of the file to download.
uri : str
uri of the file to download.
architecture : str
Architecture of the mlip model.
filename : Optional[str], optional
Name to be used for the file (defaults to the name of provided file).
Name to be used for the file (defaults to tmp_file.model).
cache_dir : Optional[Union[str, Path]], optional
Path to the folder where the file has to be saved
(defaults to "~/.cache/mlips/").
force_download : Optional[bool], optional
True to keep the downloaded model even if there are duplicates
(default: False).
keep_file : Optional[bool], optional
True to keep the downloaded model, even if there are duplicates.
(default: False, the file is deleted and only saved in the database).

Returns
-------
Expand All @@ -231,32 +190,44 @@ def download(
)
arch_dir = (cache_dir / architecture) if architecture else cache_dir

# cache_path = cache_dir.resolve()
arch_path = arch_dir.resolve()
arch_path.mkdir(parents=True, exist_ok=True)

model_name = urlparse(url).path.split("/")[-1]

file = arch_path / filename if filename else arch_path / model_name

# If file already exists, use next indexed name
stem = file.stem
i = 1
while file.exists():
i += 1
file = file.with_stem(f"{stem}_{i}")
file = arch_path / filename

# Download file
request.urlretrieve(url, file)

if force_download:
print(f"filename changed to {file}")
return cls.local_file(file=file, architecture=architecture)

# Check if the hash of the just downloaded file matches any other file
filepath = cls._check_existing_file(file)

return cls.local_file(file=filepath, architecture=architecture)
request.urlretrieve(uri, file)

model = cls.from_local(file=file, architecture=architecture)

if keep_file:
return model

file.unlink(missing_ok=True)

qb = QueryBuilder()
federicazanca marked this conversation as resolved.
Show resolved Hide resolved
qb.append(ModelData, project=["attributes", "pk", "ctime"])

# Looking for ModelData in the whole database
for i in qb.iterdict():
# If the hash matches the new model but the creation time differs,
# an identical model is already stored in the database, so use that one
if (
"model_hash" in i["ModelData_1"]["attributes"]
and i["ModelData_1"]["attributes"]["model_hash"] == model.model_hash
and i["ModelData_1"]["attributes"]["architecture"] == model.architecture
):
if i["ModelData_1"]["ctime"] != model.ctime:
federicazanca marked this conversation as resolved.
Show resolved Hide resolved
# delete_nodes(
# [model.pk],
# dry_run=False,
# create_forward=True,
# call_calc_forward=True,
# call_work_forward=True,
# )
model = load_node(i["ModelData_1"]["pk"])
break
return model

@property
def architecture(self) -> str:
Expand All @@ -271,13 +242,13 @@ def architecture(self) -> str:
return self.base.attributes.get("architecture")

@property
def filepath(self) -> str:
def model_hash(self) -> str:
"""
Return the filepath.
Return hash of the model.

Returns
-------
str
Path of the mlip model.
Hash of the MLIP model.
"""
return self.base.attributes.get("filepath")
return self.base.attributes.get("model_hash")
12 changes: 6 additions & 6 deletions aiida_mlip/helpers/help_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ def load_model(
cache_dir: Optional[Union[str, Path]] = None,
) -> ModelData:
"""
Load a model from a file path or URL.
Load a model from a file path or uri.
federicazanca marked this conversation as resolved.
Show resolved Hide resolved

If the string represents a file path, the model will be loaded from that path.
If it's a URL, the model will be downloaded from the specified location.
If it's a uri, the model will be downloaded from the specified location.
If the input model is None it returns a default model corresponding to the
default used in the Calcjobs.

Parameters
----------
model : Optional[Union[str, Path]]
Model file path or a URL for downloading the model or None to use the default.
Model file path or a uri for downloading the model or None to use the default.
architecture : str
The architecture of the model.
cache_dir : Optional[Union[str, Path]]
Expand All @@ -42,15 +42,15 @@ def load_model(
The loaded model.
"""
if model is None:
loaded_model = ModelData.download(
loaded_model = ModelData.from_uri(
"https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model", # pylint: disable=line-too-long
architecture,
cache_dir=cache_dir,
)
elif (file_path := Path(model)).is_file():
loaded_model = ModelData.local_file(file_path, architecture=architecture)
loaded_model = ModelData.from_local(file_path, architecture=architecture)
else:
loaded_model = ModelData.download(
loaded_model = ModelData.from_uri(
model, architecture=architecture, cache_dir=cache_dir
)
return loaded_model
Expand Down
4 changes: 2 additions & 2 deletions aiida_mlip/parsers/train_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ def _save_models(self, model_output: Path, compiled_model_output: Path) -> None:
Path to the compiled model output file.
"""
architecture = "mace_mp"
model = ModelData.local_file(model_output, architecture=architecture)
compiled_model = ModelData.local_file(
model = ModelData.from_local(model_output, architecture=architecture)
compiled_model = ModelData.from_local(
compiled_model_output, architecture=architecture
)

Expand Down
Loading
Loading