Adapt lm scoring #1

Open. Wants to merge 3 commits into base: main.

Changes from all commits
eole/bin/tools/LM_scoring.py (71 changes: 28 additions & 43 deletions)
@@ -8,31 +8,26 @@
 from eole.inputters.dynamic_iterator import build_dynamic_dataset_iter
 from eole.utils.loss import LossCompute
 from eole.constants import DefaultTokens, CorpusTask
-from eole.transforms import get_transforms_cls
+from eole.transforms import get_transforms_cls, make_transforms
+from eole.models.model import BaseModel
 
 from argparse import ArgumentParser
 from eole.bin import BaseBin, register_bin
 from eole.config.cli import add_model
 from eole.config import get_non_default_values
 from eole.config.run import PredictConfig
 
 """
 This script scores all sentences of a file using dynamic data.
 For this purpose we use the same pipeline as the validation of a file
 Below is an example of settings of a config.yaml file
 
-model: lm-de.news2021_step_100000.pt
-src: newstest2014-ref.de
-tgt: newstest2014-ref.de
-transforms: [onmt_tokenize]
-batch_size: 16
-gpu: 0
-src_subword_type: bpe
-src_subword_model: subwords.en_de.bpe
-src_eoletok_kwargs: '{"mode": "aggressive"}'
-tgt_subword_type: bpe
-tgt_subword_model: subwords.en_de.bpe
-tgt_eoletok_kwargs: '{"mode": "aggressive"}'
-verbose: false
-world_size: 1
-gpu_ranks: [0]
+# use symlinks to last saved step
+model_path: data/wikitext/wikitext-103-raw-v1/run/model-lm
+src: data/wikitext/wikitext-103-raw-v1/lm_input.txt
+output: data/wikitext/wikitext-103-raw-v1/lm_pred.txt
 
 Output is the data and tab separated score
 use the -output setting for preds + scores
@@ -44,13 +39,7 @@
 class LMScoring(BaseBin):
     @classmethod
     def add_args(cls, parser):
-        parser.add_argument(
-            "-config",
-            "--config",
-            "-c",
-            required=False,
-            help="Path of main YAML config file.",
-        )
+        parser.add_argument("-config", "--config", "-c", required=False, help="Path of main YAML config file.")
 
     @classmethod
     def run(cls, args):
@@ -62,9 +51,6 @@ def run(cls, args):
         config = {}
         _parser = ArgumentParser()
         add_model(_parser, PredictConfig)
-        defaults = vars(_parser.parse_args([]))
-        stuff_to_update = get_non_default_values(args, defaults)
-        config.update(stuff_to_update)
         config = PredictConfig(**config)
         init_logger(config.log_file)
         set_random_seed(config.seed, False)
@@ -75,26 +61,27 @@
         if len(config.gpu_ranks) > 1:
             logger.warning(f"gpu_ranks is {str(config.gpu_ranks)} but only the first one will be used.")
 
-        vocabs, model, model_opt = config.model.model_class.load_test_model(config)
+        vocabs, model, model_opt = BaseModel.load_test_model(config, device.index)
         pad_token = vocabs["specials"].get("pad_token", DefaultTokens.PAD)
-        padding_idx = vocabs["tgt"][pad_token]
+        padding_idx = vocabs["tgt"].tokens_to_ids[pad_token]
         criterion = torch.nn.CrossEntropyLoss(ignore_index=padding_idx, reduction="none")
         valid_loss = LossCompute(
             criterion,
             model.generator,
             tgt_shift_index=0,
-            lambda_coverage=model_opt.lambda_coverage,
-            lambda_align=model_opt.lambda_align,
+            lambda_coverage=model_opt.decoder.lambda_coverage,
+            lambda_align=model_opt.decoder.lambda_align,
             vocabs=vocabs,
         )
         valid_loss.to(device)
 
         transforms_cls = get_transforms_cls(config._all_transform)
+        transforms_cls = make_transforms(config, transforms_cls, vocabs)
 
+        # if tgt is not precised in the inference config file, used from src
+        if config.tgt is None:
+            config.tgt = config.src
         infer_iter = build_dynamic_dataset_iter(
-            args,
-            transforms_cls,
-            vocabs,
-            task=CorpusTask.INFER,
-            device_id=config.gpu,
+            config, transforms_cls, vocabs, task=CorpusTask.INFER, device_id=device.index
         )
 
         model.to(device)
@@ -110,30 +97,28 @@ def run(cls, args):
             src = batch["src"]
             src_len = batch["srclen"]
             # print(batch)
-            outputs, attns = model(src, None, src_len, with_align=False)
+            outputs, attns, _ = model(src, None, src_len, with_align=False)
             # Compute and retrieve the loss for EACH sentence
-            loss, _ = valid_loss(batch, outputs, attns)
+            loss, _, _ = valid_loss(batch, outputs, attns)
             loss = loss.view(batch_size, -1)  # (B, T)
-            losspertoken = loss.sum(1) / batch["tgt"][:, 1:, 0].ne(padding_idx).sum(1)
+            losspertoken = loss.sum(1) / batch["tgt"][:, 1:].ne(padding_idx).sum(1)
            ppl = torch.exp(losspertoken)
             cumul_loss += loss.sum().item()
-            cumul_length += batch["tgt"][:, 1:, 0].ne(padding_idx).sum().cpu()
+            cumul_length += batch["tgt"][:, 1:].ne(padding_idx).sum().cpu()
             # Now we need to rearrange the batch of ppl
             # in the original order with indices
             sent_ppl_orig = ppl.gather(
                 0,
                 torch.tensor(
-                    sorted(
-                        range(len(batch["cid_line_number"])),
-                        key=lambda k: batch["cid_line_number"][k],
-                    ),
+                    sorted(range(len(batch["cid_line_number"])), key=lambda k: batch["cid_line_number"][k]),
                     device=ppl.device,
                 ),
             )
             for j in range(batch_size):
                 ppl_file.write(str(sent_ppl_orig[j].item()) + "\n")
         logger.info(
-            "Loss: %.2f Tokens: %d Corpus PPL: %.2f" % (cumul_loss, cumul_length, np.exp(cumul_loss / cumul_length))
+            "Loss: %.2f Tokens: %d Corpus PPL: %.2f"
+            % (cumul_loss / cumul_length.item(), cumul_length, np.exp(cumul_loss / cumul_length))
         )
         ppl_file.close()
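
The arithmetic behind the updated loop can be illustrated outside of eole. Below is a minimal, self-contained PyTorch sketch of the per-sentence and corpus perplexity computation: each sentence's perplexity is exp of its summed token losses divided by its number of non-padding target tokens, and the corpus perplexity uses the running totals. The tensors here are toy placeholders (an assumption for illustration), not objects produced by eole.

import math

import torch

# Toy stand-ins (assumptions, not eole objects):
#   tgt          - target token ids, shape (B, T+1), with padding_idx marking padding
#   token_losses - unreduced cross-entropy per target position, shape (B, T),
#                  already 0.0 at ignored (padding) positions
padding_idx = 0
tgt = torch.tensor([[2, 5, 7, 0, 0], [2, 4, 6, 8, 9]])
token_losses = torch.tensor([[1.2, 0.8, 0.0, 0.0], [0.5, 0.9, 1.1, 0.7]])

# Non-padding target tokens per sentence, mirroring batch["tgt"][:, 1:].ne(padding_idx).sum(1)
n_tokens = tgt[:, 1:].ne(padding_idx).sum(1)

loss_per_token = token_losses.sum(1) / n_tokens    # mean loss per sentence
sent_ppl = torch.exp(loss_per_token)                # per-sentence perplexity

cumul_loss = token_losses.sum().item()
cumul_length = n_tokens.sum().item()
corpus_ppl = math.exp(cumul_loss / cumul_length)    # corpus-level perplexity

print(sent_ppl.tolist(), corpus_ppl)

In the script itself, one perplexity per input line is written to the -output file, restored to the original line order via the cid_line_number reordering shown above.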
