Compile and Ctranslate2 support #161

Merged · 19 commits · Dec 18, 2024
6 changes: 2 additions & 4 deletions .github/workflows/push.yml
@@ -299,17 +299,16 @@ jobs:
-model_type decoder \
-input_file /tmp/src-test.txt \
-inference_config_file eole/tests/data/inference-engine_py.yaml \
-inference_mode py \
-out /tmp/inference_engine_lm_py_outputs
- name: Test ct2-LM inference engine
run: |
head eole/tests/data/src-test.txt > /tmp/src-test.txt
python eole/tests/test_inference_engines.py \
-model eole/tests/test_model_lm_ct2 \
-model eole/tests \
-model_type decoder \
-input_file /tmp/src-test.txt \
-inference_config_file eole/tests/data/inference-engine_py.yaml \
-inference_mode ct2 \
-engine ct2 \
-out /tmp/inference_engine_lm_py_outputs
- name: Test py-SEQ2SEQ inference engine
run: |
@@ -319,7 +318,6 @@
-model_type encoder_decoder \
-input_file /tmp/src-test.txt \
-inference_config_file eole/tests/data/inference-engine_py.yaml \
-inference_mode py \
-out /tmp/inference_engine_lm_py_outputs
- name: Test embeddings_to_torch tool
run: |
2 changes: 1 addition & 1 deletion README.md
@@ -77,7 +77,7 @@ Depending on your needs, you can add various flags:
#### Requirements

- Python >= 3.10
- PyTorch >= 2.3 < 2.4
- PyTorch >= 2.5 < 2.6

#### Installation from Source

9 changes: 7 additions & 2 deletions eole/bin/run/predict.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from eole.inference_engine import InferenceEnginePY
from eole.inference_engine import InferenceEnginePY, InferenceEngineCT2

from argparse import ArgumentParser
from eole.utils.misc import use_gpu, set_random_seed
@@ -14,7 +14,12 @@
def predict(config):
set_random_seed(config.seed, use_gpu(config))

engine = InferenceEnginePY(config)
if config.engine == "eole":
engine = InferenceEnginePY(config)
elif config.engine == "ct2":
engine = InferenceEngineCT2(config, "decoder")
else:
raise ValueError("You need to use --engine with 'eole' or 'ct2'")
_, _, _ = engine.infer_file()
engine.terminate()

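A note on the new dispatch: the same selection can be driven programmatically. The sketch below is illustrative only, not part of this PR; it assumes `config` is a fully populated `PredictConfig` (e.g. built from an inference YAML) carrying the new `engine` field, and it uses `infer_list`, which both engines expose, instead of `infer_file`.

```python
# Illustrative sketch, not part of this PR.
# Assumes `config` is a populated PredictConfig with the new `engine` field.
from eole.inference_engine import InferenceEnginePY, InferenceEngineCT2


def infer_sentences(config, sentences):
    # Mirror the dispatch added in eole/bin/run/predict.py
    if config.engine == "ct2":
        engine = InferenceEngineCT2(config, "decoder")  # CT2 path is decoder-only here
    else:
        engine = InferenceEnginePY(config)
    scores, estims, preds = engine.infer_list(sentences)
    engine.terminate()
    return scores, preds
```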
3 changes: 3 additions & 0 deletions eole/config/common.py
@@ -149,6 +149,9 @@ class RunningConfig(DistributedConfig):
"fp32 to force slow fp16 model on gtx1080, "
"int8 to enable pytorch native 8-bit quantization (cpu only).",
)
torch_compile: bool = Field(
default=False, description="Use torch.compile with dynamic=True."
)

@field_validator("compute_dtype", mode="before")
@classmethod
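The field above only adds the switch; the actual call site lives elsewhere in the PR. Roughly, enabling it amounts to wrapping the model with the public `torch.compile` API. The sketch below is an assumption about that behavior, not the exact eole call site; `running_config` is a placeholder name.

```python
import torch


# Sketch of what torch_compile=True is expected to trigger (assumption, not
# the exact eole call site). dynamic=True lets compiled kernels handle
# varying batch and sequence lengths instead of recompiling per shape.
def maybe_compile(model, running_config):
    if getattr(running_config, "torch_compile", False):
        model = torch.compile(model, dynamic=True)
    return model
```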
2 changes: 1 addition & 1 deletion eole/config/models.py
@@ -167,7 +167,7 @@ class RotaryPositionConfig(Config):
default=True,
description="Interleave the head dimensions when rotary embeddings are applied. "
"Otherwise the head dimensions are sliced in half. "
"(True=default Llama from Meta (original), "
"(True= Llama from Meta (original), "
"False= used by all HuggingFace models)",
)
rotary_theta: int = Field(
4 changes: 4 additions & 0 deletions eole/config/run.py
@@ -117,6 +117,10 @@ class PredictConfig(
default="pred.txt",
description="Path to output the predictions (each line will be the decoded sequence).",
)
engine: str = Field(
default="eole",
description="engine to run inference: eole or ct2",
)

@model_validator(mode="after")
def _validate_predict_config(self):
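The new `engine` field is a free-form string, so an unsupported value only surfaces later in `predict()`. If stricter checking were wanted, a pydantic validator could reject unknown engines at config-load time. The class below is not part of the PR, just a sketch assuming pydantic v2.

```python
from pydantic import BaseModel, field_validator


class EngineChoiceSketch(BaseModel):
    # Stand-in for PredictConfig, which already defaults engine to "eole".
    engine: str = "eole"

    @field_validator("engine")
    @classmethod
    def _check_engine(cls, v: str) -> str:
        if v not in ("eole", "ct2"):
            raise ValueError("engine must be 'eole' or 'ct2'")
        return v
```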
179 changes: 142 additions & 37 deletions eole/inference_engine.py
@@ -1,4 +1,7 @@
import torch
import json
import os
import codecs
from eole.constants import CorpusTask, DefaultTokens, ModelType
from eole.inputters.dynamic_iterator import build_dynamic_dataset_iter
from eole.utils.distributed import ErrorHandler
@@ -39,6 +42,27 @@ def infer_file(self):
scores, estims, preds = self._predict(infer_iter)
else:
scores, estims, preds = self.infer_file_parallel()

out_file = codecs.open(self.config.output, "w+", "utf-8")

flatten_preds = [text for sublist in preds for text in sublist]
flatten_scores = [score for sublist in scores for score in sublist]
if estims is not None:
flatten_estims = [estim for sublist in estims for estim in sublist]
else:
flatten_estims = [1.0 for sublist in scores for score in sublist]

if self.config.with_score:
out_all = [
pred + "\t" + str(score) + "\t" + str(estim)
for (pred, score, estim) in zip(
flatten_preds, flatten_scores, flatten_estims
)
]
out_file.write("\n".join(out_all) + "\n")
else:
out_file.write("\n".join(flatten_preds) + "\n")

return scores, estims, preds

def infer_list(self, src, settings={}):
@@ -158,6 +182,7 @@ def __init__(self, config):
self.device_id = config.gpu_ranks[0]
else:
self.device_id = -1 # cpu

self.predictor = build_predictor(
config, self.device_id, logger=self.logger, report_score=True
)
@@ -166,6 +191,7 @@ def __init__(self, config):
self.transforms = make_transforms(config, self.transforms_cls, self.vocabs)
self.transform_pipe = TransformPipe.build_from(self.transforms.values())

@torch.inference_mode()
def _predict(self, infer_iter, settings={}):
self.predictor.update_settings(**settings)
scores, estims, preds = self.predictor._predict(
@@ -257,40 +283,85 @@ def __init__(self, config, model_type=None):
self.device_index = 0
self.device = "cpu"
self.transforms_cls = get_transforms_cls(self.config._all_transform)
# Build translator

ct2_config = os.path.join(
config.get_model_path() + "/ctranslate2", "config.json"
)
ct2_json = json.load(open(ct2_config, "r"))
vocabs = {}
vocabs["specials"] = {}
vocabs["specials"]["bos_token"] = ct2_json["bos_token"]
vocabs["specials"]["eos_token"] = ct2_json["eos_token"]
vocabs["specials"]["unk_token"] = ct2_json["unk_token"]
if "pad_token" in ct2_json.keys():
vocabs["specials"]["pad_token"] = ct2_json["pad_token"]
else:
vocabs["specials"]["pad_token"] = ct2_json["eos_token"]

# Build generator or translator
self.model_type = model_type
if self.model_type == ModelType.DECODER:
self.predictor = ctranslate2.Generator(
config.get_model_path(),
config.get_model_path() + "/ctranslate2",
device=self.device,
device_index=self.device_index,
)
vocab_path = os.path.join(
config.get_model_path() + "/ctranslate2", "vocabulary.json"
)
vocab = json.load(open(vocab_path, "r"))
src_vocab = pyonmttok.build_vocab_from_tokens(vocab)
config.share_vocab = True
vocabs["src"] = src_vocab
vocabs["tgt"] = src_vocab
vocabs["decoder_start_token"] = ""
else:
self.predictor = ctranslate2.Translator(
config.get_model_path(),
device=self.device,
device_index=self.device_index,
)
# Build vocab
vocab_path = config.src_subword_vocab # this is not super clean
with open(vocab_path, "r") as f:
vocab = json.load(f)
vocabs = {}
src_vocab = pyonmttok.build_vocab_from_tokens(vocab)
vocabs["src"] = src_vocab
vocabs["tgt"] = src_vocab
vocabs["decoder_start_token"] = "<s>"
# TODO: this should be loaded from model config
vocabs["specials"] = {
"bos_token": DefaultTokens.BOS,
"pad_token": DefaultTokens.PAD,
"eos_token": DefaultTokens.EOS,
"unk_token": DefaultTokens.UNK,
}
vocabs["decoder_start_token"] = ct2_json["decoder_start_token"]
if os.path.exists(
os.path.join(config.get_model_path() + "/ctranslate2", "shared_vocabulary.json")
):
vocab = json.load(
open(
os.path.join(
config.get_model_path() + "/ctranslate2",
"shared_vocabulary.json",
),
"r",
)
)
src_vocab = pyonmttok.build_vocab_from_tokens(vocab)
config.share_vocab = True
vocabs["src"] = src_vocab
vocabs["tgt"] = src_vocab

else:
vocab_src = json.load(
open(
os.path.join(
config.get_model_path() + "/ctranslate2",
"source_vocabulary.json",
),
"r",
)
)
src_vocab = pyonmttok.build_vocab_from_tokens(vocab_src)
vocab_tgt = json.load(
open(
os.path.join(
config.get_model_path() + "/ctranslate2",
"target_vocabulary.json",
),
"r",
)
)
tgt_vocab = pyonmttok.build_vocab_from_tokens(vocab_tgt)
config.share_vocab = False
vocabs["src"] = src_vocab
vocabs["tgt"] = tgt_vocab

self.vocabs = vocabs
# Build transform pipe
transforms = make_transforms(config, self.transforms_cls, self.vocabs)
self.transforms = TransformPipe.build_from(transforms.values())
self.transforms = make_transforms(config, self.transforms_cls, self.vocabs)
self.transforms_pipe = TransformPipe.build_from(self.transforms.values())

def predict_batch(self, batch, config):
input_tokens = []
@@ -305,6 +376,7 @@ def predict_batch(self, batch, config):
)
]
input_tokens.append(_input_tokens)

if self.model_type == ModelType.DECODER:
predicted_batch = self.predictor.generate_batch(
start_tokens=input_tokens,
@@ -316,14 +388,23 @@
return_scores=True,
include_prompt_in_result=False,
sampling_topk=config.top_k,
sampling_topp=config.top_p,
sampling_topp=1 if config.top_p == 0 else config.top_p,
sampling_temperature=config.temperature,
)
preds = [
[self.transforms.apply_reverse(tokens) for tokens in out.sequences]
for out in predicted_batch
]
scores = [out.scores for out in predicted_batch]
if self.transforms != {}:
preds = [
[
self.transforms_pipe.apply_reverse(nbest)
for nbest in ex.sequences
]
for ex in predicted_batch
]
else:
preds = [
[" ".join(nbest) for nbest in ex.sequences]
for ex in predicted_batch
]

elif self.model_type == ModelType.ENCODER_DECODER:
predicted_batch = self.predictor.translate_batch(
input_tokens,
@@ -334,26 +415,50 @@
max_decoding_length=config.max_length,
return_scores=True,
sampling_topk=config.top_k,
sampling_topp=config.top_p,
sampling_topp=1 if config.top_p == 0 else config.top_p,
sampling_temperature=config.temperature,
)
preds = [
[self.transforms.apply_reverse(tokens) for tokens in out.hypotheses]
for out in predicted_batch
]
scores = [out.scores for out in predicted_batch]
if self.transforms != {}:
preds = [
[
self.transforms_pipe.apply_reverse(nbest)
for nbest in ex.hypotheses
]
for ex in predicted_batch
]
else:
preds = [
[" ".join(nbest) for nbest in ex.sequences]
for ex in predicted_batch
]

scores = [[nbest for nbest in ex.scores] for ex in predicted_batch]
return scores, None, preds

def _predict(self, infer_iter, settings={}):
# TODO: convert settings to CT2 naming
scores = []
preds = []
predictions = {}
predictions["scores"] = []
predictions["preds"] = []
predictions["cid_line_number"] = []
for batch, bucket_idx in infer_iter:
_scores, _, _preds = self.predict_batch(batch, self.config)
scores += _scores
preds += _preds
return scores, None, preds
predictions["scores"] += _scores
predictions["preds"] += _preds
predictions["cid_line_number"] += batch["cid_line_number"]
sorted_data = sorted(
zip(
predictions["cid_line_number"],
predictions["preds"],
predictions["scores"],
)
)
sorted_predictions = {
"cid_line_number": [item[0] for item in sorted_data],
"preds": [item[1] for item in sorted_data],
"scores": [item[2] for item in sorted_data],
}
return sorted_predictions["scores"], None, sorted_predictions["preds"]

def _score(self, infer_iter):
raise NotImplementedError(
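For context on the CT2 path, this is roughly what `predict_batch` drives underneath for a decoder-only model, stripped of eole's transform and config plumbing. The model directory and tokens are placeholders; note the same `top_p` mapping the PR adds (eole uses 0 for "disabled", CTranslate2 expects 1.0).

```python
import ctranslate2

# Placeholders: a real run needs a converted CTranslate2 model directory and
# tokens produced by the model's own tokenizer.
generator = ctranslate2.Generator("model_dir/ctranslate2", device="cpu")

top_p = 0.0  # eole convention: 0 means "disabled"
results = generator.generate_batch(
    start_tokens=[["<s>", "▁Hello", "▁world"]],
    max_length=32,
    return_scores=True,
    include_prompt_in_result=False,
    sampling_topk=1,
    sampling_topp=1 if top_p == 0 else top_p,  # CT2 expects 1.0 when disabled
)
print(results[0].sequences[0], results[0].scores[0])
```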
2 changes: 1 addition & 1 deletion eole/models/model_saver.py
@@ -77,7 +77,7 @@ def load_checkpoint(model_path):
optim_path = os.path.join(model_path, "optimizer.pt")
if os.path.exists(optim_path):
checkpoint["optim"] = torch.load(
optim_path, map_location=torch.device("cpu")
optim_path, map_location=torch.device("cpu"), weights_only=True
)
else:
raise FileNotFoundError(f"{model_path} is not a directory.")
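For reference, the safer load pattern adopted above, shown in isolation: `weights_only=True` restricts unpickling to tensors and allow-listed types, which is enough for an optimizer state dict. The path below is a placeholder.

```python
import torch

# Equivalent standalone call; avoids executing arbitrary pickled code.
optim_state = torch.load(
    "model_dir/optimizer.pt",
    map_location=torch.device("cpu"),
    weights_only=True,
)
```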