From 1a90719bf9b1802fbac47917b0689b7e37b7ebd7 Mon Sep 17 00:00:00 2001
From: anthdr
Date: Thu, 16 Jan 2025 11:46:39 +0100
Subject: [PATCH] cleaning code

---
 eole/bin/tools/LM_scoring.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/eole/bin/tools/LM_scoring.py b/eole/bin/tools/LM_scoring.py
index bd4f454d..4c27761b 100644
--- a/eole/bin/tools/LM_scoring.py
+++ b/eole/bin/tools/LM_scoring.py
@@ -43,7 +43,13 @@ class LMScoring(BaseBin):
 
     @classmethod
     def add_args(cls, parser):
-        parser.add_argument("-config", "--config", "-c", required=False, help="Path of main YAML config file.")
+        parser.add_argument(
+            "-config",
+            "--config",
+            "-c",
+            required=False,
+            help="Path of main YAML config file."
+        )
 
     @classmethod
     def run(cls, args):
@@ -65,8 +71,7 @@ def run(cls, args):
         if len(config.gpu_ranks) > 1:
             logger.warning(f"gpu_ranks is {str(config.gpu_ranks)} but only the first one will be used.")
 
-        load_test_model = ensemble_load_test_model if len(config.model_path) > 1 else BaseModel.load_test_model
-        vocabs, model, model_opt = load_test_model(config, 0)
+        vocabs, model, model_opt = BaseModel.load_test_model(config, 0)
         pad_token = vocabs["specials"].get("pad_token", DefaultTokens.PAD)
         padding_idx = vocabs["tgt"].tokens_to_ids[pad_token]
         criterion = torch.nn.CrossEntropyLoss(ignore_index=padding_idx, reduction="none")