diff --git a/eole/bin/tools/LM_scoring.py b/eole/bin/tools/LM_scoring.py index bd4f454d..f9b19f14 100644 --- a/eole/bin/tools/LM_scoring.py +++ b/eole/bin/tools/LM_scoring.py @@ -10,7 +10,6 @@ from eole.constants import DefaultTokens, CorpusTask from eole.transforms import get_transforms_cls, make_transforms from eole.models.model import BaseModel -from eole.decoders.ensemble import load_test_model as ensemble_load_test_model from argparse import ArgumentParser from eole.bin import BaseBin, register_bin @@ -65,8 +64,7 @@ def run(cls, args): if len(config.gpu_ranks) > 1: logger.warning(f"gpu_ranks is {str(config.gpu_ranks)} but only the first one will be used.") - load_test_model = ensemble_load_test_model if len(config.model_path) > 1 else BaseModel.load_test_model - vocabs, model, model_opt = load_test_model(config, 0) + vocabs, model, model_opt = BaseModel.load_test_model(config, 0) pad_token = vocabs["specials"].get("pad_token", DefaultTokens.PAD) padding_idx = vocabs["tgt"].tokens_to_ids[pad_token] criterion = torch.nn.CrossEntropyLoss(ignore_index=padding_idx, reduction="none")