Skip to content

Commit

Permalink
black reformat, removed main, adapted yaml example
Browse files Browse the repository at this point in the history
  • Loading branch information
anthdr committed Jan 17, 2025
1 parent f3c000d commit 2674c4f
Showing 1 changed file with 2 additions and 19 deletions.
21 changes: 2 additions & 19 deletions eole/bin/tools/LM_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from eole.constants import DefaultTokens, CorpusTask
from eole.transforms import get_transforms_cls, make_transforms
from eole.models.model import BaseModel
from eole.decoders.ensemble import load_test_model as ensemble_load_test_model

from argparse import ArgumentParser
from eole.bin import BaseBin, register_bin
Expand All @@ -23,9 +22,6 @@
Below is an example of settings of a config.yaml file
verbose: false
n_best: 3
top_p: 0.9
beam_size: 10
world_size: 1
gpu_ranks: [0]
# use symlinks to last saved step
Expand All @@ -43,12 +39,7 @@
class LMScoring(BaseBin):
@classmethod
def add_args(cls, parser):
parser.add_argument(
"-config",
"--config",
"-c",
required=False,
help="Path of main YAML config file.")
parser.add_argument("-config", "--config", "-c", required=False, help="Path of main YAML config file.")

@classmethod
def run(cls, args):
Expand All @@ -70,8 +61,7 @@ def run(cls, args):
if len(config.gpu_ranks) > 1:
logger.warning(f"gpu_ranks is {str(config.gpu_ranks)} but only the first one will be used.")

load_test_model = ensemble_load_test_model if len(config.model_path) > 1 else BaseModel.load_test_model
vocabs, model, model_opt = load_test_model(config, 0)
vocabs, model, model_opt = BaseModel.load_test_model(config, device.index)
pad_token = vocabs["specials"].get("pad_token", DefaultTokens.PAD)
padding_idx = vocabs["tgt"].tokens_to_ids[pad_token]
criterion = torch.nn.CrossEntropyLoss(ignore_index=padding_idx, reduction="none")
Expand Down Expand Up @@ -133,10 +123,3 @@ def run(cls, args):
ppl_file.close()

os.system('paste "' + config.src + '" "' + config.output + '".ppl > "' + config.output + '"')


if __name__ == "__main__":
parser = ArgumentParser()
LMScoring.add_args(parser)
args = parser.parse_args()
LMScoring.run(args)

0 comments on commit 2674c4f

Please sign in to comment.