From 2674c4fe29a592993f073d65fff0a3e96136e8b7 Mon Sep 17 00:00:00 2001 From: anthdr Date: Fri, 17 Jan 2025 18:26:16 +0100 Subject: [PATCH] black reformat, removed main, adapted yaml example --- eole/bin/tools/LM_scoring.py | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/eole/bin/tools/LM_scoring.py b/eole/bin/tools/LM_scoring.py index 6c09d1f4..5eb2a903 100644 --- a/eole/bin/tools/LM_scoring.py +++ b/eole/bin/tools/LM_scoring.py @@ -10,7 +10,6 @@ from eole.constants import DefaultTokens, CorpusTask from eole.transforms import get_transforms_cls, make_transforms from eole.models.model import BaseModel -from eole.decoders.ensemble import load_test_model as ensemble_load_test_model from argparse import ArgumentParser from eole.bin import BaseBin, register_bin @@ -23,9 +22,6 @@ Below is an example of settings of a config.yaml file verbose: false -n_best: 3 -top_p: 0.9 -beam_size: 10 world_size: 1 gpu_ranks: [0] # use symlinks to last saved step @@ -43,12 +39,7 @@ class LMScoring(BaseBin): @classmethod def add_args(cls, parser): - parser.add_argument( - "-config", - "--config", - "-c", - required=False, - help="Path of main YAML config file.") + parser.add_argument("-config", "--config", "-c", required=False, help="Path of main YAML config file.") @classmethod def run(cls, args): @@ -70,8 +61,7 @@ def run(cls, args): if len(config.gpu_ranks) > 1: logger.warning(f"gpu_ranks is {str(config.gpu_ranks)} but only the first one will be used.") - load_test_model = ensemble_load_test_model if len(config.model_path) > 1 else BaseModel.load_test_model - vocabs, model, model_opt = load_test_model(config, 0) + vocabs, model, model_opt = BaseModel.load_test_model(config, device.index) pad_token = vocabs["specials"].get("pad_token", DefaultTokens.PAD) padding_idx = vocabs["tgt"].tokens_to_ids[pad_token] criterion = torch.nn.CrossEntropyLoss(ignore_index=padding_idx, reduction="none") @@ -133,10 +123,3 @@ def run(cls, args): ppl_file.close() os.system('paste "' + config.src + '" "' + config.output + '".ppl > "' + config.output + '"') - - -if __name__ == "__main__": - parser = ArgumentParser() - LMScoring.add_args(parser) - args = parser.parse_args() - LMScoring.run(args)