Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore(pt): refactor the command function interface #4225

Merged
merged 3 commits into from
Oct 16, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 92 additions & 57 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,16 +239,27 @@ def get_backend_info(self) -> dict:
}


def train(FLAGS):
log.info("Configuration path: %s", FLAGS.INPUT)
def train(
input_file: str,
init_model: Optional[str],
restart: Optional[str],
finetune: Optional[str],
init_frz_model: Optional[str],
model_branch: str,
skip_neighbor_stat: bool = False,
use_pretrain_script: bool = False,
force_load: bool = False,
output: str = "out.json",
):
iProzd marked this conversation as resolved.
Show resolved Hide resolved
log.info("Configuration path: %s", input_file)
SummaryPrinter()()
with open(FLAGS.INPUT) as fin:
with open(input_file) as fin:
config = json.load(fin)
# ensure suffix, as in the command line help, we say "path prefix of checkpoint files"
if FLAGS.init_model is not None and not FLAGS.init_model.endswith(".pt"):
FLAGS.init_model += ".pt"
if FLAGS.restart is not None and not FLAGS.restart.endswith(".pt"):
FLAGS.restart += ".pt"
if init_model is not None and not init_model.endswith(".pt"):
init_model += ".pt"
if restart is not None and not restart.endswith(".pt"):
restart += ".pt"

# update multitask config
multi_task = "model_dict" in config["model"]
Expand All @@ -262,26 +273,24 @@ def train(FLAGS):

# update fine-tuning config
finetune_links = None
if FLAGS.finetune is not None:
if finetune is not None:
config["model"], finetune_links = get_finetune_rules(
FLAGS.finetune,
finetune,
config["model"],
model_branch=FLAGS.model_branch,
change_model_params=FLAGS.use_pretrain_script,
model_branch=model_branch,
change_model_params=use_pretrain_script,
)
# update init_model or init_frz_model config if necessary
if (
FLAGS.init_model is not None or FLAGS.init_frz_model is not None
) and FLAGS.use_pretrain_script:
if FLAGS.init_model is not None:
init_state_dict = torch.load(FLAGS.init_model, map_location=DEVICE)
if (init_model is not None or init_frz_model is not None) and use_pretrain_script:
if init_model is not None:
init_state_dict = torch.load(init_model, map_location=DEVICE)
if "model" in init_state_dict:
init_state_dict = init_state_dict["model"]
config["model"] = init_state_dict["_extra_state"]["model_params"]
else:
config["model"] = json.loads(
torch.jit.load(
FLAGS.init_frz_model, map_location=DEVICE
init_frz_model, map_location=DEVICE
).get_model_def_script()
)

Expand All @@ -291,7 +300,7 @@ def train(FLAGS):

# do neighbor stat
min_nbor_dist = None
if not FLAGS.skip_neighbor_stat:
if not skip_neighbor_stat:
log.info(
"Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)"
)
Expand Down Expand Up @@ -320,16 +329,16 @@ def train(FLAGS):
)
)

with open(FLAGS.output, "w") as fp:
with open(output, "w") as fp:
json.dump(config, fp, indent=4)

trainer = get_trainer(
config,
FLAGS.init_model,
FLAGS.restart,
FLAGS.finetune,
FLAGS.force_load,
FLAGS.init_frz_model,
init_model,
restart,
finetune,
force_load,
init_frz_model,
shared_links=shared_links,
finetune_links=finetune_links,
)
Expand All @@ -343,26 +352,39 @@ def train(FLAGS):
trainer.run()


def freeze(FLAGS):
model = inference.Tester(FLAGS.model, head=FLAGS.head).model
def freeze(
model: str,
output: str = "frozen_model.pth",
head: Optional[str] = None,
):
model = inference.Tester(model, head=head).model
model.eval()
model = torch.jit.script(model)
extra_files = {}
torch.jit.save(
model,
FLAGS.output,
output,
extra_files,
)
log.info(f"Saved frozen model to {FLAGS.output}")


def change_bias(FLAGS):
if FLAGS.INPUT.endswith(".pt"):
old_state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE)
log.info(f"Saved frozen model to {output}")


def change_bias(
input_file: str,
mode: str = "change",
bias_value: Optional[list] = None,
datafile: Optional[str] = None,
system: str = ".",
numb_batch: int = 0,
model_branch: Optional[str] = None,
output: Optional[str] = None,
):
iProzd marked this conversation as resolved.
Show resolved Hide resolved
if input_file.endswith(".pt"):
old_state_dict = torch.load(input_file, map_location=env.DEVICE)
model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict))
model_params = model_state_dict["_extra_state"]["model_params"]
elif FLAGS.INPUT.endswith(".pth"):
old_model = torch.jit.load(FLAGS.INPUT, map_location=env.DEVICE)
elif input_file.endswith(".pth"):
old_model = torch.jit.load(input_file, map_location=env.DEVICE)
model_params_string = old_model.get_model_def_script()
model_params = json.loads(model_params_string)
old_state_dict = old_model.state_dict()
Expand All @@ -373,10 +395,7 @@ def change_bias(FLAGS):
"or a frozen model with a .pth extension"
)
multi_task = "model_dict" in model_params
model_branch = FLAGS.model_branch
bias_adjust_mode = (
"change-by-statistic" if FLAGS.mode == "change" else "set-by-statistic"
)
bias_adjust_mode = "change-by-statistic" if mode == "change" else "set-by-statistic"
if multi_task:
assert (
model_branch is not None
Expand All @@ -393,24 +412,24 @@ def change_bias(FLAGS):
else model_params["model_dict"][model_branch]["type_map"]
)
model_to_change = model if not multi_task else model[model_branch]
if FLAGS.INPUT.endswith(".pt"):
if input_file.endswith(".pt"):
wrapper = ModelWrapper(model)
wrapper.load_state_dict(old_state_dict["model"])
else:
# for .pth
model.load_state_dict(old_state_dict)

if FLAGS.bias_value is not None:
if bias_value is not None:
# use user-defined bias
assert model_to_change.model_type in [
"ener"
], "User-defined bias is only available for energy model!"
assert (
len(FLAGS.bias_value) == len(type_map)
len(bias_value) == len(type_map)
), f"The number of elements in the bias should be the same as that in the type_map: {type_map}."
old_bias = model_to_change.get_out_bias()
bias_to_set = torch.tensor(
FLAGS.bias_value, dtype=old_bias.dtype, device=old_bias.device
bias_value, dtype=old_bias.dtype, device=old_bias.device
).view(old_bias.shape)
model_to_change.set_out_bias(bias_to_set)
log.info(
Expand All @@ -421,11 +440,11 @@ def change_bias(FLAGS):
updated_model = model_to_change
else:
# calculate bias on given systems
if FLAGS.datafile is not None:
with open(FLAGS.datafile) as datalist:
if datafile is not None:
with open(datafile) as datalist:
all_sys = datalist.read().splitlines()
else:
all_sys = expand_sys_str(FLAGS.system)
all_sys = expand_sys_str(system)
data_systems = process_systems(all_sys)
data_single = DpLoaderSet(
data_systems,
Expand All @@ -438,7 +457,7 @@ def change_bias(FLAGS):
data_requirement = mock_loss.label_requirement
data_requirement += training.get_additional_data_requirement(model_to_change)
data_single.add_data_requirement(data_requirement)
nbatches = FLAGS.numb_batch if FLAGS.numb_batch != 0 else float("inf")
nbatches = numb_batch if numb_batch != 0 else float("inf")
sampled_data = make_stat_input(
data_single.systems,
data_single.dataloaders,
Expand All @@ -453,11 +472,9 @@ def change_bias(FLAGS):
else:
model[model_branch] = updated_model

if FLAGS.INPUT.endswith(".pt"):
if input_file.endswith(".pt"):
output_path = (
FLAGS.output
if FLAGS.output is not None
else FLAGS.INPUT.replace(".pt", "_updated.pt")
output if output is not None else input_file.replace(".pt", "_updated.pt")
)
wrapper = ModelWrapper(model)
if "model" in old_state_dict:
Expand All @@ -470,9 +487,7 @@ def change_bias(FLAGS):
else:
# for .pth
output_path = (
FLAGS.output
if FLAGS.output is not None
else FLAGS.INPUT.replace(".pth", "_updated.pth")
output if output is not None else input_file.replace(".pth", "_updated.pth")
)
model = torch.jit.script(model)
torch.jit.save(
Expand All @@ -499,7 +514,18 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None):
log.info("DeePMD version: %s", __version__)

if FLAGS.command == "train":
train(FLAGS)
train(
input_file=FLAGS.INPUT,
init_model=FLAGS.init_model,
restart=FLAGS.restart,
finetune=FLAGS.finetune,
init_frz_model=FLAGS.init_frz_model,
model_branch=FLAGS.model_branch,
skip_neighbor_stat=FLAGS.skip_neighbor_stat,
use_pretrain_script=FLAGS.use_pretrain_script,
force_load=FLAGS.force_load,
output=FLAGS.output,
)
elif FLAGS.command == "freeze":
if Path(FLAGS.checkpoint_folder).is_dir():
checkpoint_path = Path(FLAGS.checkpoint_folder)
Expand All @@ -508,9 +534,18 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None):
else:
FLAGS.model = FLAGS.checkpoint_folder
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
freeze(FLAGS)
freeze(model=FLAGS.model, output=FLAGS.output, head=FLAGS.head)
elif FLAGS.command == "change-bias":
change_bias(FLAGS)
change_bias(
input_file=FLAGS.INPUT,
mode=FLAGS.mode,
bias_value=FLAGS.bias_value,
datafile=FLAGS.datafile,
system=FLAGS.system,
numb_batch=FLAGS.numb_batch,
model_branch=FLAGS.model_branch,
output=FLAGS.output,
)
else:
raise RuntimeError(f"Invalid command {FLAGS.command}!")

Expand Down
Loading