Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

142 model download change #150

Merged
merged 20 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aiida_mlip/data/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class ModelData(SinglefileData):
Set the file for the node.
from_local(file, architecture, filename=None):
Create a ModelData instance from a local file.
from_url(url, architecture, filename=None, cache_dir=None, keep_file=False)
from_uri(url, architecture, filename=None, cache_dir=None, keep_file=False)
federicazanca marked this conversation as resolved.
Show resolved Hide resolved
Download a file from a URL and save it as ModelData.

Other Parameters
Expand Down Expand Up @@ -155,7 +155,7 @@ def from_local(

@classmethod
# pylint: disable=too-many-arguments
def from_url(
def from_uri(
cls,
url: str,
architecture: str,
Expand Down
4 changes: 2 additions & 2 deletions aiida_mlip/helpers/help_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ def load_model(
The loaded model.
"""
if model is None:
loaded_model = ModelData.from_url(
loaded_model = ModelData.from_uri(
"https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model", # pylint: disable=line-too-long
architecture,
cache_dir=cache_dir,
)
elif (file_path := Path(model)).is_file():
loaded_model = ModelData.from_local(file_path, architecture=architecture)
else:
loaded_model = ModelData.from_url(
loaded_model = ModelData.from_uri(
model, architecture=architecture, cache_dir=cache_dir
)
return loaded_model
Expand Down
2 changes: 1 addition & 1 deletion docs/source/user_guide/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Usage

.. code-block:: python

model = ModelData.from_url('http://yoururl.test/model', architecture='mace', filename='model', cache_dir='/home/mlip/', force_download=False)
model = ModelData.from_uri('http://yoururl.test/model', architecture='mace', filename='model', cache_dir='/home/mlip/', force_download=False)

- The architecture of the model file can be accessed using the `architecture` property:

Expand Down
2 changes: 1 addition & 1 deletion docs/source/user_guide/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ In this example we use MACE with a model that we download from this URL: "https:

from aiida_mlip.data.model import ModelData
url = "https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model"
model = ModelData.from_url(url, architecture="mace", cache_dir="/.cache/")
model = ModelData.from_uri(url, architecture="mace", cache_dir="/.cache/")

If we already have the model saved in some folder we can save it as:

Expand Down
4 changes: 2 additions & 2 deletions examples/tutorials/geometry-optimisation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
"source": [
"from aiida_mlip.data.model import ModelData\n",
"url = \"https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model\"\n",
"model = ModelData.from_url(url, architecture=\"mace_mp\", cache_dir=\"mlips\")"
"model = ModelData.from_uri(url, architecture=\"mace_mp\", cache_dir=\"mlips\")"
]
},
{
Expand Down Expand Up @@ -368,7 +368,7 @@
" return traj.get_step_structure(index.value)\n",
"\n",
"url = \"https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model\"\n",
"model = ModelData.from_url(url, architecture=\"mace_mp\", cache_dir=\"mlips\")\n",
"model = ModelData.from_uri(url, architecture=\"mace_mp\", cache_dir=\"mlips\")\n",
"list_of_nodes = []\n",
"\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/tutorials/singlepoint.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
"source": [
"from aiida_mlip.data.model import ModelData\n",
"url = \"https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model\"\n",
"model = ModelData.from_url(url, architecture=\"mace_mp\", cache_dir=\"mlips\")"
"model = ModelData.from_uri(url, architecture=\"mace_mp\", cache_dir=\"mlips\")"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions tests/calculations/test_geomopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_geomopt(fixture_sandbox, generate_calc_job, janus_code, model_folder):
"--out",
"aiida-results.xyz",
"--calc-kwargs",
"{'default_dtype': 'float64', 'model': 'modelcopy.model'}",
"{'default_dtype': 'float64', 'model': 'mlff.model'}",
"--traj",
"aiida-traj.xyz",
]
Expand All @@ -58,7 +58,7 @@ def test_geomopt(fixture_sandbox, generate_calc_job, janus_code, model_folder):

# Check the attributes of the returned `CalcInfo`
assert sorted(fixture_sandbox.get_content_list()) == sorted(
["aiida.xyz", "modelcopy.model"]
["aiida.xyz", "mlff.model"]
)
assert isinstance(calc_info, datastructures.CalcInfo)
assert isinstance(calc_info.codes_info[0], datastructures.CodeInfo)
Expand Down
8 changes: 4 additions & 4 deletions tests/calculations/test_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_MD(fixture_sandbox, generate_calc_job, janus_code, model_folder):
"--summary",
"md_summary.yml",
"--calc-kwargs",
"{'default_dtype': 'float64', 'model': 'modelcopy.model'}",
"{'default_dtype': 'float64', 'model': 'mlff.model'}",
"--ensemble",
"nve",
"--temp",
Expand Down Expand Up @@ -86,7 +86,7 @@ def test_MD(fixture_sandbox, generate_calc_job, janus_code, model_folder):

# Check the attributes of the returned `CalcInfo`
assert sorted(fixture_sandbox.get_content_list()) == sorted(
["aiida.xyz", "modelcopy.model"]
["aiida.xyz", "mlff.model"]
)
assert isinstance(calc_info, datastructures.CalcInfo)
assert isinstance(calc_info.codes_info[0], datastructures.CodeInfo)
Expand Down Expand Up @@ -125,7 +125,7 @@ def test_MD_with_config(
"--arch",
"mace",
"--calc-kwargs",
"{'model': 'modelcopy.model'}",
"{'model': 'mlff.model'}",
"--config",
"config.yaml",
"--ensemble",
Expand All @@ -152,7 +152,7 @@ def test_MD_with_config(
[
"aiida.xyz",
"config.yaml",
"modelcopy.model",
"mlff.model",
]
)
assert isinstance(calc_info, datastructures.CalcInfo)
Expand Down
4 changes: 2 additions & 2 deletions tests/calculations/test_singlepoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_singlepoint(fixture_sandbox, generate_calc_job, janus_code, model_folde
"--out",
"aiida-results.xyz",
"--calc-kwargs",
"{'default_dtype': 'float64', 'model': 'modelcopy.model'}",
"{'default_dtype': 'float64', 'model': 'mlff.model'}",
]

retrieve_list = [
Expand All @@ -56,7 +56,7 @@ def test_singlepoint(fixture_sandbox, generate_calc_job, janus_code, model_folde

# Check the attributes of the returned `CalcInfo`
assert sorted(fixture_sandbox.get_content_list()) == sorted(
["aiida.xyz", "modelcopy.model"]
["aiida.xyz", "mlff.model"]
)
assert isinstance(calc_info, datastructures.CalcInfo)
assert isinstance(calc_info.codes_info[0], datastructures.CodeInfo)
Expand Down
2 changes: 1 addition & 1 deletion tests/calculations/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def test_prepare_tune(fixture_sandbox, generate_calc_job, janus_code, config_fol

# Check the attributes of the returned `CalcInfo`
assert sorted(fixture_sandbox.get_content_list()) == sorted(
["mlip_train.yml", "modelcopy.model"]
["mlip_train.yml", "mlff.model"]
)
assert isinstance(calc_info, datastructures.CalcInfo)
assert isinstance(calc_info.codes_info[0], datastructures.CodeInfo)
Expand Down
8 changes: 4 additions & 4 deletions tests/data/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_download_fresh_file_keep(tmp_path):

# Construct a ModelData instance downloading a non-cached file
# pylint:disable=line-too-long
model = ModelData.from_url(
model = ModelData.from_uri(
url="https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model",
filename="mace.model",
cache_dir=tmp_path,
Expand All @@ -68,7 +68,7 @@ def test_download_fresh_file(tmp_path):

# Construct a ModelData instance downloading a non-cached file
# pylint:disable=line-too-long
model = ModelData.from_url(
model = ModelData.from_uri(
url="https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model",
filename="mace.model",
cache_dir=tmp_path,
Expand All @@ -85,15 +85,15 @@ def test_no_download_cached_file(tmp_path):
"""Test if the caching prevents saving duplicate model in the database."""

# pylint:disable=line-too-long
existing_model = ModelData.from_url(
existing_model = ModelData.from_uri(
url="https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model",
filename="mace_existing.model",
cache_dir=tmp_path,
architecture="mace_mp",
)
# Construct a ModelData instance that should use the cached file
# pylint:disable=line-too-long
model = ModelData.from_url(
model = ModelData.from_uri(
url="https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model",
cache_dir=tmp_path,
filename="test_model.model",
Expand Down