Skip to content

Commit

Permalink
refactor(train): move yaml saving to save_model() (#2156)
Browse files: browse the repository at this point in the history
  • Loading branch information
xingchensong authored Nov 23, 2023
1 parent 69987c3 commit eafd44a
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions wenet/utils/train_utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -395,6 +395,7 @@ def save_model(model, info_dict):
rank = int(os.environ.get('RANK', 0))
tag = info_dict["tag"]
model_dir = info_dict["model_dir"]
# save ckpt
if info_dict["train_engine"] == "deepspeed":
# NOTE(xcsong): All ranks should call this API, but only rank 0
# save the general model params. see:
Expand All @@ -410,6 +411,11 @@ def save_model(model, info_dict):
# NOTE(xcsong): For torch_ddp, only rank-0 should call this.
save_model_path = os.path.join(model_dir, '{}.pt'.format(tag))
save_checkpoint(model, save_model_path, info_dict)
# save yaml
if rank == 0:
with open("{}/{}.yaml".format(model_dir, tag), 'w') as fout:
data = yaml.dump(info_dict)
fout.write(data)


def wenet_join(group_join, info_dict):
Expand Down Expand Up @@ -598,6 +604,3 @@ def log_per_epoch(writer, info_dict):
if int(os.environ.get('RANK', 0)) == 0:
writer.add_scalar('epoch/cv_loss', info_dict["cv_loss"], epoch)
writer.add_scalar('epoch/lr', info_dict["lr"], epoch)
with open("{}/{}.yaml".format(info_dict["model_dir"], epoch), 'w') as fout:
data = yaml.dump(info_dict)
fout.write(data)

0 comments on commit eafd44a

Please sign in to comment.