diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index a0694c41c5..7c8a95c5e7 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -239,16 +239,27 @@ def get_backend_info(self) -> dict: } -def train(FLAGS): - log.info("Configuration path: %s", FLAGS.INPUT) +def train( + input_file: str, + init_model: Optional[str], + restart: Optional[str], + finetune: Optional[str], + init_frz_model: Optional[str], + model_branch: str, + skip_neighbor_stat: bool = False, + use_pretrain_script: bool = False, + force_load: bool = False, + output: str = "out.json", +): + log.info("Configuration path: %s", input_file) SummaryPrinter()() - with open(FLAGS.INPUT) as fin: + with open(input_file) as fin: config = json.load(fin) # ensure suffix, as in the command line help, we say "path prefix of checkpoint files" - if FLAGS.init_model is not None and not FLAGS.init_model.endswith(".pt"): - FLAGS.init_model += ".pt" - if FLAGS.restart is not None and not FLAGS.restart.endswith(".pt"): - FLAGS.restart += ".pt" + if init_model is not None and not init_model.endswith(".pt"): + init_model += ".pt" + if restart is not None and not restart.endswith(".pt"): + restart += ".pt" # update multitask config multi_task = "model_dict" in config["model"] @@ -262,26 +273,24 @@ def train(FLAGS): # update fine-tuning config finetune_links = None - if FLAGS.finetune is not None: + if finetune is not None: config["model"], finetune_links = get_finetune_rules( - FLAGS.finetune, + finetune, config["model"], - model_branch=FLAGS.model_branch, - change_model_params=FLAGS.use_pretrain_script, + model_branch=model_branch, + change_model_params=use_pretrain_script, ) # update init_model or init_frz_model config if necessary - if ( - FLAGS.init_model is not None or FLAGS.init_frz_model is not None - ) and FLAGS.use_pretrain_script: - if FLAGS.init_model is not None: - init_state_dict = torch.load(FLAGS.init_model, map_location=DEVICE) + if (init_model is not None or init_frz_model is not None) and use_pretrain_script: + if init_model is not None: + init_state_dict = torch.load(init_model, map_location=DEVICE) if "model" in init_state_dict: init_state_dict = init_state_dict["model"] config["model"] = init_state_dict["_extra_state"]["model_params"] else: config["model"] = json.loads( torch.jit.load( - FLAGS.init_frz_model, map_location=DEVICE + init_frz_model, map_location=DEVICE ).get_model_def_script() ) @@ -291,7 +300,7 @@ def train(FLAGS): # do neighbor stat min_nbor_dist = None - if not FLAGS.skip_neighbor_stat: + if not skip_neighbor_stat: log.info( "Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)" ) @@ -320,16 +329,16 @@ def train(FLAGS): ) ) - with open(FLAGS.output, "w") as fp: + with open(output, "w") as fp: json.dump(config, fp, indent=4) trainer = get_trainer( config, - FLAGS.init_model, - FLAGS.restart, - FLAGS.finetune, - FLAGS.force_load, - FLAGS.init_frz_model, + init_model, + restart, + finetune, + force_load, + init_frz_model, shared_links=shared_links, finetune_links=finetune_links, ) @@ -343,26 +352,39 @@ def train(FLAGS): trainer.run() -def freeze(FLAGS): - model = inference.Tester(FLAGS.model, head=FLAGS.head).model +def freeze( + model: str, + output: str = "frozen_model.pth", + head: Optional[str] = None, +): + model = inference.Tester(model, head=head).model model.eval() model = torch.jit.script(model) extra_files = {} torch.jit.save( model, - FLAGS.output, + output, extra_files, ) - log.info(f"Saved frozen model to {FLAGS.output}") - - -def change_bias(FLAGS): - if FLAGS.INPUT.endswith(".pt"): - old_state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE) + log.info(f"Saved frozen model to {output}") + + +def change_bias( + input_file: str, + mode: str = "change", + bias_value: Optional[list] = None, + datafile: Optional[str] = None, + system: str = ".", + numb_batch: int = 0, + model_branch: Optional[str] = None, + output: Optional[str] = None, +): + if input_file.endswith(".pt"): + old_state_dict = torch.load(input_file, map_location=env.DEVICE) model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict)) model_params = model_state_dict["_extra_state"]["model_params"] - elif FLAGS.INPUT.endswith(".pth"): - old_model = torch.jit.load(FLAGS.INPUT, map_location=env.DEVICE) + elif input_file.endswith(".pth"): + old_model = torch.jit.load(input_file, map_location=env.DEVICE) model_params_string = old_model.get_model_def_script() model_params = json.loads(model_params_string) old_state_dict = old_model.state_dict() @@ -373,10 +395,7 @@ def change_bias(FLAGS): "or a frozen model with a .pth extension" ) multi_task = "model_dict" in model_params - model_branch = FLAGS.model_branch - bias_adjust_mode = ( - "change-by-statistic" if FLAGS.mode == "change" else "set-by-statistic" - ) + bias_adjust_mode = "change-by-statistic" if mode == "change" else "set-by-statistic" if multi_task: assert ( model_branch is not None @@ -393,24 +412,24 @@ def change_bias(FLAGS): else model_params["model_dict"][model_branch]["type_map"] ) model_to_change = model if not multi_task else model[model_branch] - if FLAGS.INPUT.endswith(".pt"): + if input_file.endswith(".pt"): wrapper = ModelWrapper(model) wrapper.load_state_dict(old_state_dict["model"]) else: # for .pth model.load_state_dict(old_state_dict) - if FLAGS.bias_value is not None: + if bias_value is not None: # use user-defined bias assert model_to_change.model_type in [ "ener" ], "User-defined bias is only available for energy model!" assert ( - len(FLAGS.bias_value) == len(type_map) + len(bias_value) == len(type_map) ), f"The number of elements in the bias should be the same as that in the type_map: {type_map}." old_bias = model_to_change.get_out_bias() bias_to_set = torch.tensor( - FLAGS.bias_value, dtype=old_bias.dtype, device=old_bias.device + bias_value, dtype=old_bias.dtype, device=old_bias.device ).view(old_bias.shape) model_to_change.set_out_bias(bias_to_set) log.info( @@ -421,11 +440,11 @@ def change_bias(FLAGS): updated_model = model_to_change else: # calculate bias on given systems - if FLAGS.datafile is not None: - with open(FLAGS.datafile) as datalist: + if datafile is not None: + with open(datafile) as datalist: all_sys = datalist.read().splitlines() else: - all_sys = expand_sys_str(FLAGS.system) + all_sys = expand_sys_str(system) data_systems = process_systems(all_sys) data_single = DpLoaderSet( data_systems, @@ -438,7 +457,7 @@ def change_bias(FLAGS): data_requirement = mock_loss.label_requirement data_requirement += training.get_additional_data_requirement(model_to_change) data_single.add_data_requirement(data_requirement) - nbatches = FLAGS.numb_batch if FLAGS.numb_batch != 0 else float("inf") + nbatches = numb_batch if numb_batch != 0 else float("inf") sampled_data = make_stat_input( data_single.systems, data_single.dataloaders, @@ -453,11 +472,9 @@ def change_bias(FLAGS): else: model[model_branch] = updated_model - if FLAGS.INPUT.endswith(".pt"): + if input_file.endswith(".pt"): output_path = ( - FLAGS.output - if FLAGS.output is not None - else FLAGS.INPUT.replace(".pt", "_updated.pt") + output if output is not None else input_file.replace(".pt", "_updated.pt") ) wrapper = ModelWrapper(model) if "model" in old_state_dict: @@ -470,9 +487,7 @@ def change_bias(FLAGS): else: # for .pth output_path = ( - FLAGS.output - if FLAGS.output is not None - else FLAGS.INPUT.replace(".pth", "_updated.pth") + output if output is not None else input_file.replace(".pth", "_updated.pth") ) model = torch.jit.script(model) torch.jit.save( @@ -499,7 +514,18 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None): log.info("DeePMD version: %s", __version__) if FLAGS.command == "train": - train(FLAGS) + train( + input_file=FLAGS.INPUT, + init_model=FLAGS.init_model, + restart=FLAGS.restart, + finetune=FLAGS.finetune, + init_frz_model=FLAGS.init_frz_model, + model_branch=FLAGS.model_branch, + skip_neighbor_stat=FLAGS.skip_neighbor_stat, + use_pretrain_script=FLAGS.use_pretrain_script, + force_load=FLAGS.force_load, + output=FLAGS.output, + ) elif FLAGS.command == "freeze": if Path(FLAGS.checkpoint_folder).is_dir(): checkpoint_path = Path(FLAGS.checkpoint_folder) @@ -508,9 +534,18 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None): else: FLAGS.model = FLAGS.checkpoint_folder FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth")) - freeze(FLAGS) + freeze(model=FLAGS.model, output=FLAGS.output, head=FLAGS.head) elif FLAGS.command == "change-bias": - change_bias(FLAGS) + change_bias( + input_file=FLAGS.INPUT, + mode=FLAGS.mode, + bias_value=FLAGS.bias_value, + datafile=FLAGS.datafile, + system=FLAGS.system, + numb_batch=FLAGS.numb_batch, + model_branch=FLAGS.model_branch, + output=FLAGS.output, + ) else: raise RuntimeError(f"Invalid command {FLAGS.command}!") diff --git a/source/tests/pt/model/test_deeppot.py b/source/tests/pt/model/test_deeppot.py index 8917c62cce..7f530b0a5e 100644 --- a/source/tests/pt/model/test_deeppot.py +++ b/source/tests/pt/model/test_deeppot.py @@ -2,9 +2,6 @@ import json import os import unittest -from argparse import ( - Namespace, -) from copy import ( deepcopy, ) @@ -123,12 +120,11 @@ class TestDeepPotFrozen(TestDeepPot): def setUp(self): super().setUp() frozen_model = "frozen_model.pth" - ns = Namespace( + freeze( model=self.model, output=frozen_model, head=None, ) - freeze(ns) self.model = frozen_model # Note: this can not actually disable cuda device to be used diff --git a/source/tests/pt/test_init_frz_model.py b/source/tests/pt/test_init_frz_model.py index 1cbc1b29b6..69c738d6bd 100644 --- a/source/tests/pt/test_init_frz_model.py +++ b/source/tests/pt/test_init_frz_model.py @@ -4,9 +4,6 @@ import shutil import tempfile import unittest -from argparse import ( - Namespace, -) from copy import ( deepcopy, ) @@ -70,12 +67,11 @@ def setUp(self): if imodel in [0, 1]: trainer.run() - ns = Namespace( + freeze( model="model.pt", output=frozen_model, head=None, ) - freeze(ns) self.models.append(frozen_model) def test_dp_test(self):