ppl is not exact yet
anthdr committed Jan 15, 2025
1 parent 8c33e38 commit 9433f72
Showing 1 changed file with 53 additions and 32 deletions.
eole/bin/tools/LM_scoring.py (85 changes: 53 additions & 32 deletions)
@@ -8,31 +8,18 @@
 from eole.inputters.dynamic_iterator import build_dynamic_dataset_iter
 from eole.utils.loss import LossCompute
 from eole.constants import DefaultTokens, CorpusTask
-from eole.transforms import get_transforms_cls
+from eole.transforms import get_transforms_cls, make_transforms
 from eole.models.model import BaseModel
-from eole.decoders.ensemble import load_test_model as ensemble_load_test_model
 
 from argparse import ArgumentParser
 from eole.bin import BaseBin, register_bin
 from eole.config.cli import add_model
 from eole.config import get_non_default_values
 from eole.config.run import PredictConfig
 
 """
 This script scores all sentences of a file using dynamic data.
 For this purpose we use the same pipeline as the validation of a file
 Below is an example of settings of a config.yaml file
-model: lm-de.news2021_step_100000.pt
-src: newstest2014-ref.de
-tgt: newstest2014-ref.de
-transforms: [onmt_tokenize]
-batch_size: 16
-gpu: 0
-src_subword_type: bpe
-src_subword_model: subwords.en_de.bpe
-src_eoletok_kwargs: '{"mode": "aggressive"}'
-tgt_subword_type: bpe
-tgt_subword_model: subwords.en_de.bpe
-tgt_eoletok_kwargs: '{"mode": "aggressive"}'
 Output is the data and tab separated score
 use the -output setting for preds + scores
@@ -62,39 +49,57 @@ def run(cls, args):
         config = {}
         _parser = ArgumentParser()
         add_model(_parser, PredictConfig)
+        """
         defaults = vars(_parser.parse_args([]))
         stuff_to_update = get_non_default_values(args, defaults)
         config.update(stuff_to_update)
+        """
         config = PredictConfig(**config)
         init_logger(config.log_file)
         set_random_seed(config.seed, False)
         ppl_file = codecs.open(config.output + ".ppl", "w+", "utf-8")
 
         # no tensor_parallel support
-        device = torch.device("cuda", config.gpu_ranks[0]) if len(config.gpu_ranks) > 0 else torch.device("cpu")
+        device = (
+            torch.device("cuda", config.gpu_ranks[0])
+            if len(config.gpu_ranks) > 0
+            else torch.device("cpu")
+        )
         if len(config.gpu_ranks) > 1:
-            logger.warning(f"gpu_ranks is {str(config.gpu_ranks)} but only the first one will be used.")
+            logger.warning(
+                f"gpu_ranks is {str(config.gpu_ranks)} but only the first one will be used."
+            )
 
-        load_test_model = ensemble_load_test_model if len(config.model_path) > 1 else BaseModel.load_test_model
-        vocabs, model, model_opt = load_test_model(config, 0)
-
+        vocabs, model, model_opt = config.model.model_class.load_test_model(config)
         pad_token = vocabs["specials"].get("pad_token", DefaultTokens.PAD)
-        padding_idx = vocabs["tgt"][pad_token]
-        criterion = torch.nn.CrossEntropyLoss(ignore_index=padding_idx, reduction="none")
+        padding_idx = vocabs['tgt'].tokens_to_ids[pad_token]
+
+        criterion = torch.nn.CrossEntropyLoss(
+            ignore_index=padding_idx, reduction="none"
+        )
         valid_loss = LossCompute(
             criterion,
             model.generator,
             tgt_shift_index=0,
-            lambda_coverage=model_opt.lambda_coverage,
-            lambda_align=model_opt.lambda_align,
+            lambda_coverage=model_opt.decoder.lambda_coverage,
+            lambda_align=model_opt.decoder.lambda_align,
             vocabs=vocabs
         )
         valid_loss.to(device)
 
         transforms_cls = get_transforms_cls(config._all_transform)
+        transforms_cls = make_transforms(config, transforms_cls, vocabs)
+
+        if config.tgt is None:
+            config.tgt = config.src
 
         infer_iter = build_dynamic_dataset_iter(
-            args,
+            config,
             transforms_cls,
             vocabs,
             task=CorpusTask.INFER,
-            device_id=config.gpu,
+            device_id=device.index
         )
 
         model.to(device)
@@ -110,14 +115,14 @@ def run(cls, args):
             src = batch["src"]
             src_len = batch["srclen"]
             # print(batch)
-            outputs, attns = model(src, None, src_len, with_align=False)
+            outputs, attns, _ = model(src, None, src_len, with_align=False)
             # Compute and retrieve the loss for EACH sentence
-            loss, _ = valid_loss(batch, outputs, attns)
+            loss, _, _ = valid_loss(batch, outputs, attns)
             loss = loss.view(batch_size, -1)  # (B, T)
-            losspertoken = loss.sum(1) / batch["tgt"][:, 1:, 0].ne(padding_idx).sum(1)
+            losspertoken = loss.sum(1) / batch["tgt"][:, 1:].ne(padding_idx).sum(1)
             ppl = torch.exp(losspertoken)
             cumul_loss += loss.sum().item()
-            cumul_length += batch["tgt"][:, 1:, 0].ne(padding_idx).sum().cpu()
+            cumul_length += batch["tgt"][:, 1:].ne(padding_idx).sum().cpu()
             # Now we need to rearrange the batch of ppl
             # in the original order with indices
             sent_ppl_orig = ppl.gather(
@@ -133,8 +138,24 @@ def run(cls, args):
             for j in range(batch_size):
                 ppl_file.write(str(sent_ppl_orig[j].item()) + "\n")
         logger.info(
-            "Loss: %.2f Tokens: %d Corpus PPL: %.2f" % (cumul_loss, cumul_length, np.exp(cumul_loss / cumul_length))
+            "Loss: %.2f Tokens: %d Corpus PPL: %.2f"
+            % (cumul_loss / cumul_length.item(), cumul_length, np.exp(cumul_loss / cumul_length))
         )
         ppl_file.close()
 
-        os.system('paste "' + config.src + '" "' + config.output + '".ppl > "' + config.output + '"')
+        os.system(
+            'paste "'
+            + config.src
+            + '" "'
+            + config.output
+            + '".ppl > "'
+            + config.output
+            + '"'
+        )
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    LMScoring.add_args(parser)
+    args = parser.parse_args()
+    LMScoring.run(args)
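
For readers of this diff, the arithmetic behind losspertoken, ppl, and the corpus PPL logged at the end can be sketched on toy numbers. This is an illustration only, not part of the commit: the loss values and token counts below are made up, and padding handling is reduced to a precomputed per-sentence token count.

import torch

# Toy (B, T) matrix of per-token losses, as produced by loss.view(batch_size, -1).
loss = torch.tensor([[2.0, 1.0, 0.0],   # 0.0 stands for a padded position
                     [1.5, 0.5, 1.0]])
# Non-padding token counts per sentence,
# i.e. batch["tgt"][:, 1:].ne(padding_idx).sum(1) in the script.
n_tokens = torch.tensor([2, 3])

loss_per_token = loss.sum(1) / n_tokens              # mean NLL per sentence
sent_ppl = torch.exp(loss_per_token)                 # per-sentence perplexity
corpus_ppl = torch.exp(loss.sum() / n_tokens.sum())  # exp(total loss / total tokens)
print(sent_ppl, corpus_ppl)

Note the distinction the commit's logging fix makes: the "Loss" field now reports the per-token loss (cumul_loss / cumul_length) rather than the raw summed loss, while corpus PPL remains exp of the per-token loss over the whole corpus, as in the old line.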
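The final shell call joins each input sentence with its tab-separated score (paste src output.ppl > output), matching the docstring's promise of "data and tab separated score". A pure-Python equivalent could look like the sketch below; paste_scores is a hypothetical helper, not part of this file, and it assumes both files hold one line per sentence.

import codecs

def paste_scores(src_path, out_path):
    # Hypothetical helper: replicate `paste src out.ppl > out` without a shell,
    # sidestepping the quoting issues of building an os.system command string.
    with codecs.open(src_path, "r", "utf-8") as src, codecs.open(
        out_path + ".ppl", "r", "utf-8"
    ) as ppl, codecs.open(out_path, "w", "utf-8") as out:
        for sent, score in zip(src, ppl):
            # `score` already ends with "\n", as written by ppl_file above
            out.write(sent.rstrip("\n") + "\t" + score)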
