diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 668e10c2..54a38114 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.5 + rev: v0.6.1 hooks: - id: ruff args: [--fix] @@ -30,7 +30,7 @@ repos: args: [--check-filenames] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.0 + rev: v1.11.1 hooks: - id: mypy exclude: (tests|examples)/ diff --git a/aviary/train.py b/aviary/train.py index fe71c660..a8ab420a 100644 --- a/aviary/train.py +++ b/aviary/train.py @@ -428,7 +428,7 @@ def checkpoint_model( if checkpoint_endpoint == "local": os.makedirs(f"{ROOT}/models", exist_ok=True) checkpoint_path = ( - f"{ROOT}/models/{timestamp+'-' if timestamp else ''}{run_name}-{epochs}.pth" + f"{ROOT}/models/{timestamp + '-' if timestamp else ''}{run_name}-{epochs}.pth" ) torch.save(checkpoint_dict, checkpoint_path) @@ -438,7 +438,7 @@ def checkpoint_model( ), "can't save model checkpoint to Weights and Biases, wandb.run is None" torch.save( checkpoint_dict, - f"{wandb.run.dir}/{timestamp+'-' if timestamp else ''}{run_name}-{epochs}.pth", + f"{wandb.run.dir}/{timestamp + '-' if timestamp else ''}{run_name}-{epochs}.pth", ) diff --git a/examples/wrenformer/mat_bench/compare_spglib_vs_aflow_wyckoff_labels.py b/examples/wrenformer/mat_bench/compare_spglib_vs_aflow_wyckoff_labels.py index 6382d24e..b6339be4 100644 --- a/examples/wrenformer/mat_bench/compare_spglib_vs_aflow_wyckoff_labels.py +++ b/examples/wrenformer/mat_bench/compare_spglib_vs_aflow_wyckoff_labels.py @@ -2,18 +2,14 @@ import os import pandas as pd +import pymatviz as pmv from matminer.datasets import load_dataset from pymatgen.core import Structure from pymatgen.symmetry.analyzer import SpacegroupAnalyzer -from pymatviz import sankey_from_2_df_cols, spacegroup_sunburst -from pymatviz.utils import crystal_sys_from_spg_num from tqdm import tqdm +import aviary.wren.utils as wren_utils from aviary import ROOT -from aviary.wren.utils import ( - get_protostructure_label_from_aflow, - get_protostructure_label_from_spglib, -) from examples.wrenformer.mat_bench import DATA_PATHS __author__ = "Janosh Riebesell" @@ -31,92 +27,92 @@ # %% -df_perovskites = pd.read_json(DATA_PATHS["matbench_perovskites"]).set_index("mbid") -df_perovskites = df_perovskites.rename(columns={"wyckoff": "spglib_wyckoff"}) -df_perovskites["structure"] = [ - Structure.from_dict(struct) for struct in df_perovskites.structure -] +df_perov = pd.read_json(DATA_PATHS["matbench_perovskites"]).set_index("mbid") +df_perov = df_perov.rename(columns={"wyckoff": "spglib_wyckoff"}) +df_perov["structure"] = df_perov.structure.map(Structure.from_dict) # %% # takes ~6h (when running uninterrupted) -for idx, struct in tqdm(df_perovskites.structure.items(), total=len(df_perovskites)): - if pd.isna(df_perovskites.aflow_wyckoff[idx]): - df_perovskites.loc[idx, "aflow_wyckoff"] = get_protostructure_label_from_aflow( - struct, "/Users/janosh/bin/aflow" +for idx, struct in tqdm(df_perov.structure.items(), total=len(df_perov)): + if pd.isna(df_perov.aflow_wyckoff[idx]): + df_perov.loc[idx, "aflow_wyckoff"] = ( + wren_utils.get_protostructure_label_from_aflow( + struct, "/Users/janosh/bin/aflow" + ) ) # %% # takes ~30 sec -for struct in tqdm(df_perovskites.structure, total=len(df_perovskites)): - get_protostructure_label_from_spglib(struct) +for struct in tqdm(df_perov.structure, total=len(df_perov)): + wren_utils.get_protostructure_label_from_spglib(struct) # %% -df_perovskites.dropna().query("wyckoff != aflow_wyckoff") +df_perov.dropna().query("wyckoff != aflow_wyckoff") # %% print( "Percentage of materials with spglib label != aflow label: " - f"{len(df_perovskites.query('wyckoff != aflow_wyckoff')) / len(df_perovskites):.0%}" + f"{len(df_perov.query('wyckoff != aflow_wyckoff')) / len(df_perov):.0%}" ) # %% -df_perovskites.drop("structure", axis=1).to_csv( +df_perov.drop("structure", axis=1).to_csv( f"{ROOT}/datasets/matbench_perovskites_protostructure_labels.csv" ) # %% -df_perovskites = pd.read_csv( +df_perov = pd.read_csv( f"{ROOT}/datasets/matbench_perovskites_protostructure_labels.csv" ).set_index("mbid") # %% for src in ("aflow", "spglib"): - df_perovskites[f"{src}_spg_num"] = ( - df_perovskites[f"{src}_wyckoff"].str.split("_").str[2].astype(int) + df_perov[f"{src}_spg_num"] = ( + df_perov[f"{src}_wyckoff"].str.split("_").str[2].astype(int) ) # %% -fig = spacegroup_sunburst(df_perovskites.spglib_spg) +fig = pmv.spacegroup_sunburst(df_perov.spglib_spg) fig.update_layout(title=dict(text="Spglib Spacegroups", x=0.5, y=0.93)) # fig.write_image(f"{MODULE_DIR}/plots/matbench_perovskites_aflow_sunburst.pdf") # %% -fig = spacegroup_sunburst(df_perovskites.aflow_spg, title="Aflow") +fig = pmv.spacegroup_sunburst(df_perov.aflow_spg, title="Aflow") fig.update_layout(title=dict(text="Aflow Spacegroups", x=0.5, y=0.85)) # fig.write_image(f"{MODULE_DIR}/plots/matbench_perovskites_spglib_sunburst.pdf") # %% -df_perovskites = load_dataset("matbench_perovskites") +df_perov = load_dataset("matbench_perovskites") -df_perovskites["spglib_spg_num"] = df_perovskites.structure.map( +df_perov["spglib_spg_num"] = df_perov.structure.map( lambda struct: SpacegroupAnalyzer(struct).get_space_group_number() ) # %% for src in ("aflow", "spglib"): - df_perovskites[f"{src}_crys_sys"] = df_perovskites[f"{src}_spg_num"].map( - crystal_sys_from_spg_num + df_perov[f"{src}_crys_sys"] = df_perov[f"{src}_spg_num"].map( + pmv.utils.crystal_sys_from_spg_num ) # %% -fig = sankey_from_2_df_cols(df_perovskites, ["aflow_spg_num", "spglib_spg_num"]) +fig = pmv.sankey_from_2_df_cols(df_perov, ["aflow_spg_num", "spglib_spg_num"]) fig.update_layout(title="Matbench Perovskites Aflow vs Spglib Spacegroups") # %% -fig = sankey_from_2_df_cols(df_perovskites, ["aflow_crys_sys", "spglib_crys_sys"]) +fig = pmv.sankey_from_2_df_cols(df_perov, ["aflow_crys_sys", "spglib_crys_sys"]) fig.update_layout(title="Aflow vs Spglib Crystal Systems") diff --git a/examples/wrenformer/mat_bench/make_plots.py b/examples/wrenformer/mat_bench/make_plots.py index 5cf9cdd6..f324353b 100644 --- a/examples/wrenformer/mat_bench/make_plots.py +++ b/examples/wrenformer/mat_bench/make_plots.py @@ -10,10 +10,10 @@ import pandas as pd import plotly.express as px +import pymatviz as pmv from matbench import MatbenchBenchmark from matbench.constants import CLF_KEY, REG_KEY from matbench.metadata import mbv01_metadata as matbench_metadata -from pymatviz.powerups import add_identity_line from sklearn.metrics import r2_score, roc_auc_score from examples.wrenformer.mat_bench import DATA_PATHS @@ -209,7 +209,7 @@ "value ": "Predicted formation energy (eV/atom)", }, ) -add_identity_line(fig) +pmv.powerups.add_identity_line(fig) fig.update_layout(legend=dict(x=0.02, y=0.95, xanchor="left", title="Models")) diff --git a/examples/wrenformer/mat_bench/utils.py b/examples/wrenformer/mat_bench/utils.py index 0d130b67..85d37420 100644 --- a/examples/wrenformer/mat_bench/utils.py +++ b/examples/wrenformer/mat_bench/utils.py @@ -14,14 +14,14 @@ def _int_keys(dct: dict) -> dict: return {int(k) if k.lstrip("-").isdigit() else k: v for k, v in dct.items()} -def recursive_dict_merge(d1: dict, d2: dict) -> dict: +def recursive_dict_merge(dict1: dict, dict2: dict) -> dict: """Merge two dicts recursively.""" - for key in d2: - if key in d1 and isinstance(d1[key], dict) and isinstance(d2[key], dict): - recursive_dict_merge(d1[key], d2[key]) + for key, val2 in dict2.items(): + if key in dict1 and isinstance(dict1[key], dict) and isinstance(val2, dict): + recursive_dict_merge(dict1[key], val2) else: - d1[key] = d2[key] - return d1 + dict1[key] = val2 + return dict1 def merge_json_on_disk( diff --git a/pyproject.toml b/pyproject.toml index a930018d..fa0e42d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,8 +73,10 @@ no_implicit_optional = false [tool.ruff] line-length = 90 target-version = "py39" -extend-include = ["*.ipynb"] -lint.select = [ +output-format = "concise" + +[tool.ruff.lint] +select = [ "B", # flake8-bugbear "C4", # flake8-comprehensions "D", # pydocstyle @@ -105,7 +107,7 @@ lint.select = [ "W", # pycodestyle warning "YTT", # flake8-2020 ] -lint.ignore = [ +ignore = [ "C408", # Unnecessary dict call - rewrite as a literal "D100", # Missing docstring in public module "D104", # Missing docstring in public package @@ -116,8 +118,8 @@ lint.ignore = [ "PLR", # pylint refactor "PT006", # pytest-parametrize-names-wrong-type ] -lint.pydocstyle.convention = "google" -lint.isort.known-third-party = ["wandb"] +pydocstyle.convention = "google" +isort.known-third-party = ["wandb"] [tool.ruff.lint.per-file-ignores] "tests/*" = ["D"]