diff --git a/examples/nlp/language_modeling/conf/megatron_retro_inference.yaml b/examples/nlp/language_modeling/conf/megatron_retro_inference.yaml index 1b99a65f46ad9..6cde27f555277 100644 --- a/examples/nlp/language_modeling/conf/megatron_retro_inference.yaml +++ b/examples/nlp/language_modeling/conf/megatron_retro_inference.yaml @@ -3,42 +3,40 @@ inference: top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. temperature: 1.0 # sampling temperature - add_BOS: True # add the bos token at the begining of the prompt + add_BOS: False # add the bos token at the begining of the prompt tokens_to_generate: 30 # The minimum length of the sequence to be generated. all_probs: False # whether return the log prob for all the tokens in vocab repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty. min_tokens_to_generate: 0 # The minimum length of the sequence to be generated. 
compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False - + end_strings: ["<|endoftext|>"] # generation will stop when one of these tokens is generated + # RETRO-specific arguments + retro_inference: + retro_gpt_retrieved_length: 128 + retro_num_neighbors: 2 + ft_neighbours: 0 + reuse_top: False trainer: devices: 1 num_nodes: 1 accelerator: gpu logger: False # logger provided by exp_manager - precision: 16 # 16, 32, or bf16 - -inference_batch_size: 2 + precision: 32 # 16, 32, or bf16 + use_distributed_sampler: False + tensor_model_parallel_size: -1 pipeline_model_parallel_size: -1 pipeline_model_parallel_split_rank: -1 # used for encoder and decoder model (0 for others) -retro_model_file: null # RETRO nemo file path +megatron_amp_O2: False # Enable O2-level automatic mixed precision to save memory -use_predict_method: False # whether to use the predict method +retro_model_file: null # Retro nemo file path +checkpoint_dir: null # checkpoint file dir. This is used to load the PTL checkpoint generated during the Retro training +checkpoint_name: null # PTL checkpoint file name, only used for PTL checkpoint loading +hparams_file: null # model configuration file, only used for PTL checkpoint loading -prompts: # prompts for RETRO model inference - - "hello," - - "good morning," - - "good afternoon," - - "good evening," - -########### Faiss service parameters ######## -retrieval_service: - strategy: RetroModelTextGenerationStrategy # choose customized inference strategy - neighbors: 4 - frequent_query: False # for the current token generation, frequently update the retrieval context. 
If false, update it every 64 tokens - pad_tokens: True # pad the tokens at the beginning to make it minimum of 64 tokens for retrieving at least once - store_retrieved: False # whether store the retrieved documents, so it can be checked - combo_service: - service_ip: '0.0.0.0' - service_port: 17181 \ No newline at end of file +# RETRO inference +prompt: "sample prompt" +neighbors: + - "neighbor text 1" + - "neighbor text 2" \ No newline at end of file diff --git a/examples/nlp/language_modeling/conf/megatron_retro_inference_legacy.yaml b/examples/nlp/language_modeling/conf/megatron_retro_inference_legacy.yaml new file mode 100644 index 0000000000000..83d88339b30b6 --- /dev/null +++ b/examples/nlp/language_modeling/conf/megatron_retro_inference_legacy.yaml @@ -0,0 +1,46 @@ +# (This inferencing script for native NeMo RETRO will be soon deprecated. For new inferencing script for mcore RETRO, see ./megatron_retro_inference.yaml) + +inference: + greedy: False # Whether or not to use sampling ; use greedy decoding otherwise + top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. + temperature: 1.0 # sampling temperature + add_BOS: True # add the bos token at the begining of the prompt + tokens_to_generate: 30 # The minimum length of the sequence to be generated. + all_probs: False # whether return the log prob for all the tokens in vocab + repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty. + min_tokens_to_generate: 0 # The minimum length of the sequence to be generated. 
+ compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False + + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + logger: False # logger provided by exp_manager + precision: 16 # 16, 32, or bf16 + +inference_batch_size: 2 +tensor_model_parallel_size: -1 +pipeline_model_parallel_size: -1 +pipeline_model_parallel_split_rank: -1 # used for encoder and decoder model (0 for others) +retro_model_file: null # RETRO nemo file path + +use_predict_method: False # whether to use the predict method + +prompts: # prompts for RETRO model inference + - "hello," + - "good morning," + - "good afternoon," + - "good evening," + +########### Faiss service parameters ######## +retrieval_service: + strategy: RetroModelTextGenerationStrategy # choose customized inference strategy + neighbors: 4 + frequent_query: False # for the current token generation, frequently update the retrieval context. If false, update it every 64 tokens + pad_tokens: True # pad the tokens at the beginning to make it minimum of 64 tokens for retrieving at least once + store_retrieved: False # whether store the retrieved documents, so it can be checked + combo_service: + service_ip: '0.0.0.0' + service_port: 17181 \ No newline at end of file diff --git a/examples/nlp/language_modeling/conf/megatron_retro_qatask.yaml b/examples/nlp/language_modeling/conf/megatron_retro_qatask.yaml new file mode 100644 index 0000000000000..a68d11e770870 --- /dev/null +++ b/examples/nlp/language_modeling/conf/megatron_retro_qatask.yaml @@ -0,0 +1,40 @@ +inference: + greedy: False # Whether or not to use sampling ; use greedy decoding otherwise + top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. 
+ temperature: 1.0 # sampling temperature + add_BOS: False # add the bos token at the beginning of the prompt + tokens_to_generate: 30 # The maximum number of tokens to generate. + all_probs: False # whether to return the log probs for all the tokens in vocab + repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty. + min_tokens_to_generate: 0 # The minimum length of the sequence to be generated. + compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False + end_strings: ["<|endoftext|>"] # generation will stop when one of these tokens is generated + # RETRO-specific arguments + retro_inference: + retro_gpt_retrieved_length: 128 + retro_num_neighbors: 2 + ft_neighbours: 0 + reuse_top: False + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + logger: False # logger provided by exp_manager + precision: 32 # 16, 32, or bf16 + use_distributed_sampler: False + +tensor_model_parallel_size: -1 +pipeline_model_parallel_size: -1 +pipeline_model_parallel_split_rank: -1 # used for encoder and decoder model (0 for others) +megatron_amp_O2: False # Enable O2-level automatic mixed precision to save memory + +retro_model_file: null # Retro nemo file path +checkpoint_dir: null # checkpoint file dir. This is used to load the PTL checkpoint generated during the Retro training +checkpoint_name: null # PTL checkpoint file name, only used for PTL checkpoint loading +hparams_file: null # model configuration file, only used for PTL checkpoint loading + +# qa tasks +qa_file_path: null +pred_file_path: null diff --git a/examples/nlp/language_modeling/megatron_retro_eval.py b/examples/nlp/language_modeling/megatron_retro_eval.py index 9978bab78bfc0..89e3fe9c3ddbb 100644 --- a/examples/nlp/language_modeling/megatron_retro_eval.py +++ b/examples/nlp/language_modeling/megatron_retro_eval.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,128 +12,119 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime import os -from examples.nlp.language_modeling.megatron_gpt_eval import RequestDataSet -from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from torch.utils.data import DataLoader +import torch +from omegaconf import OmegaConf +from pytorch_lightning.trainer.trainer import Trainer +from torch.utils.data import DataLoader, Dataset -from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel -from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam -from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector +from nemo.collections.nlp.models.language_modeling.megatron_retro_model import MegatronRetroModel +from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel +from nemo.collections.nlp.parts.nlp_overrides import CustomProgressBar, NLPDDPStrategy from nemo.core.config import hydra_runner - -try: - from megatron.core import parallel_state - - HAVE_MEGATRON_CORE = True - -except (ImportError, ModuleNotFoundError): - - HAVE_MEGATRON_CORE = False +from nemo.utils.app_state import AppState +from nemo.utils.model_utils import inject_model_parallel_rank """ -This is the script to run RETRO Model text generation. +This is the script to run Retro text generation. Usage: - Assume the model has TP=1, PP=1 - run greedy inference from a nemo file: + Currently, Mcore-based RETRO only support batch-size of 1. 
+ Example running greedy inference from a distributed checkpoint dir: python megatron_retro_eval.py \ + checkpoint_dir=PATH_TO_CHECKPOINT \ + checkpoint_name=CHECKPOINT_NAME \ + inference.greedy=True \ + inference.add_BOS=False \ trainer.devices=1 \ trainer.num_nodes=1 \ - trainer.accelerator=gpu \ - trainer.precision=16 \ - inference.tokens_to_generate=128 \ - inference.greedy=True \ - retro_model_file=path_to_retro_nemo_file \ tensor_model_parallel_size=-1 \ pipeline_model_parallel_size=-1 \ - retrieval_service.faiss_devices='0' \ - retrieval_service.faiss_index=path_to_faiss_index \ - retrieval_service.retrieval_index=path_to_retrieval_dataset \ - retrieval_service.neighbors=20 -""" + prompt="sample prompt" \ + inference.retro_inference.retro_num_neighbors=2 \ + neighbors=["neighbor text 1", "neighbor text 2"] -@hydra_runner(config_path="conf", config_name="megatron_retro_inference") -def main(cfg) -> None: - trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer) + ``` +""" - model_path = cfg.retro_model_file +if not torch.cuda.is_available(): + raise EnvironmentError("GPU is needed for the inference") - save_restore_connector = NLPSaveRestoreConnector() - if os.path.isdir(model_path): - save_restore_connector.model_extracted_dir = model_path +class RequestDataSet(Dataset): + def __init__(self, sentences, neighbors): + super().__init__() + self.sentences = sentences + self.neighbors = neighbors - model_cfg = MegatronRetrievalModel.restore_from( - model_path, trainer=trainer, return_config=True, save_restore_connector=save_restore_connector, - ) + def __len__(self,): + return len(self.sentences) - with open_dict(model_cfg): - model_cfg.precision = trainer.precision - model_cfg.sequence_parallel = False - model_cfg.activations_checkpoint_granularity = None - model_cfg.activations_checkpoint_method = None - - if ( - cfg.tensor_model_parallel_size < 0 - or cfg.pipeline_model_parallel_size < 0 - or cfg.get('pipeline_model_parallel_split_rank', -1) < 0 - ): - 
with open_dict(cfg): - cfg.tensor_model_parallel_size = model_cfg.get('tensor_model_parallel_size', 1) - cfg.pipeline_model_parallel_size = model_cfg.get('pipeline_model_parallel_size', 1) - cfg.pipeline_model_parallel_split_rank = model_cfg.get('pipeline_model_parallel_split_rank', 0) - - model = MegatronRetrievalModel.restore_from( - model_path, trainer=trainer, save_restore_connector=save_restore_connector, override_config_path=model_cfg, - ) + def __getitem__(self, idx): + return {'prompts': self.sentences[idx], 'neighbors': self.neighbors[idx]} - length_params: LengthParam = { - "max_length": cfg.inference.tokens_to_generate, - "min_length": cfg.inference.min_tokens_to_generate, - } - sampling_params: SamplingParam = { - "use_greedy": cfg.inference.greedy, - "temperature": cfg.inference.temperature, - "top_k": cfg.inference.top_k, - "top_p": cfg.inference.top_p, - "repetition_penalty": cfg.inference.repetition_penalty, - "add_BOS": cfg.inference.add_BOS, - "all_probs": cfg.inference.all_probs, - "compute_logprob": cfg.inference.compute_logprob, - } +@hydra_runner(config_path="conf", config_name="megatron_retro_inference") +def main(cfg) -> None: + + # trainer required for restoring model parallel models + trainer = Trainer( + strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)), + **cfg.trainer, + callbacks=[CustomProgressBar()], + ) - # check whether the DDP is initialized - if not parallel_state.is_initialized(): + if cfg.checkpoint_dir: + app_state = AppState() + if cfg.tensor_model_parallel_size > 1 or cfg.pipeline_model_parallel_size > 1: + app_state.model_parallel_size = cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size + app_state.tensor_model_parallel_size = cfg.tensor_model_parallel_size + app_state.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size + ( + app_state.tensor_model_parallel_rank, + app_state.pipeline_model_parallel_rank, + app_state.model_parallel_size, + app_state.data_parallel_size, + 
app_state.pipeline_model_parallel_split_rank, + app_state.virtual_pipeline_model_parallel_rank, + ) = fake_initialize_model_parallel( + world_size=app_state.model_parallel_size, + rank=trainer.global_rank, + tensor_model_parallel_size_=cfg.tensor_model_parallel_size, + pipeline_model_parallel_size_=cfg.pipeline_model_parallel_size, + pipeline_model_parallel_split_rank_=cfg.pipeline_model_parallel_split_rank, + ) + checkpoint_path = os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name) + # checkpoint_path is a dir in case of distributed checkpointing + if not os.path.isdir(checkpoint_path): + # legacy checkpoint needs model parallel rank injection + checkpoint_path = inject_model_parallel_rank(os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name)) + model = MegatronRetroModel.load_from_checkpoint( + checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer + ) + else: + raise ValueError("Requiring distributed checkpoint dir for loading Mcore RETRO.") - def dummy(): - return + model.freeze() - if model.trainer.strategy.launcher is not None: - model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer) - model.trainer.strategy.setup_environment() + # Have to turn off activations_checkpoint_method for inference + try: + model.model.language_model.encoder.activations_checkpoint_method = None + except AttributeError: + pass + prompt = [cfg.prompt] + neighbors = [cfg.neighbors] + ds = RequestDataSet(prompt, neighbors) + bs = 1 + request_dl = DataLoader(dataset=ds, batch_size=bs) config = OmegaConf.to_container(cfg.inference) - retrieval_service = OmegaConf.to_container(cfg.retrieval_service) - model.set_inference_config(config, retrieval_service) - - if not cfg.use_predict_method: - # First method of running text generation, call model.generate method - response = model.generate( - inputs=OmegaConf.to_container(cfg.prompts), - length_params=length_params, - sampling_params=sampling_params, - strategy=model.inference_strategy, - ) - else: - # Second method 
of running text generation, call trainer.predict - ds = RequestDataSet(OmegaConf.to_container(cfg.prompts)) - request_dl = DataLoader(dataset=ds, batch_size=cfg.inference_batch_size) - response = trainer.predict(model, request_dl) + model.set_inference_config(config) + + response = trainer.predict(model, request_dl) print("***************************") print(response) diff --git a/examples/nlp/language_modeling/megatron_retro_eval_legacy.py b/examples/nlp/language_modeling/megatron_retro_eval_legacy.py new file mode 100644 index 0000000000000..69222acedd343 --- /dev/null +++ b/examples/nlp/language_modeling/megatron_retro_eval_legacy.py @@ -0,0 +1,145 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +from examples.nlp.language_modeling.megatron_gpt_eval import RequestDataSet +from omegaconf.omegaconf import OmegaConf, open_dict +from pytorch_lightning import Trainer +from torch.utils.data import DataLoader + +from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel +from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector +from nemo.core.config import hydra_runner + +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + +""" +This is the script to run RETRO Model text generation. +(This inference script for native NeMo RETRO will soon be deprecated. For the new inference script for mcore RETRO, see ./megatron_retro_eval.py) + +Usage: + Assume the model has TP=1, PP=1 + run greedy inference from a nemo file: + python megatron_retro_eval_legacy.py \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + trainer.accelerator=gpu \ + trainer.precision=16 \ + inference.tokens_to_generate=128 \ + inference.greedy=True \ + retro_model_file=path_to_retro_nemo_file \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + retrieval_service.faiss_devices='0' \ + retrieval_service.faiss_index=path_to_faiss_index \ + retrieval_service.retrieval_index=path_to_retrieval_dataset \ + retrieval_service.neighbors=20 +""" + + +@hydra_runner(config_path="conf", config_name="megatron_retro_inference_legacy") +def main(cfg) -> None: + trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer) + + model_path = cfg.retro_model_file + + save_restore_connector = NLPSaveRestoreConnector() + + if os.path.isdir(model_path): + save_restore_connector.model_extracted_dir = model_path + + model_cfg = MegatronRetrievalModel.restore_from( + model_path, trainer=trainer, return_config=True, 
save_restore_connector=save_restore_connector, + ) + + with open_dict(model_cfg): + model_cfg.precision = trainer.precision + model_cfg.sequence_parallel = False + model_cfg.activations_checkpoint_granularity = None + model_cfg.activations_checkpoint_method = None + + if ( + cfg.tensor_model_parallel_size < 0 + or cfg.pipeline_model_parallel_size < 0 + or cfg.get('pipeline_model_parallel_split_rank', -1) < 0 + ): + with open_dict(cfg): + cfg.tensor_model_parallel_size = model_cfg.get('tensor_model_parallel_size', 1) + cfg.pipeline_model_parallel_size = model_cfg.get('pipeline_model_parallel_size', 1) + cfg.pipeline_model_parallel_split_rank = model_cfg.get('pipeline_model_parallel_split_rank', 0) + + model = MegatronRetrievalModel.restore_from( + model_path, trainer=trainer, save_restore_connector=save_restore_connector, override_config_path=model_cfg, + ) + + length_params: LengthParam = { + "max_length": cfg.inference.tokens_to_generate, + "min_length": cfg.inference.min_tokens_to_generate, + } + + sampling_params: SamplingParam = { + "use_greedy": cfg.inference.greedy, + "temperature": cfg.inference.temperature, + "top_k": cfg.inference.top_k, + "top_p": cfg.inference.top_p, + "repetition_penalty": cfg.inference.repetition_penalty, + "add_BOS": cfg.inference.add_BOS, + "all_probs": cfg.inference.all_probs, + "compute_logprob": cfg.inference.compute_logprob, + } + + # check whether the DDP is initialized + if parallel_state.is_unitialized(): + + def dummy(): + return + + if model.trainer.strategy.launcher is not None: + model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer) + model.trainer.strategy.setup_environment() + + config = OmegaConf.to_container(cfg.inference) + retrieval_service = OmegaConf.to_container(cfg.retrieval_service) + model.set_inference_config(config, retrieval_service) + + if not cfg.use_predict_method: + # First method of running text generation, call model.generate method + response = model.generate( + 
inputs=OmegaConf.to_container(cfg.prompts), + length_params=length_params, + sampling_params=sampling_params, + strategy=model.inference_strategy, + ) + else: + # Second method of running text generation, call trainer.predict + ds = RequestDataSet(OmegaConf.to_container(cfg.prompts)) + request_dl = DataLoader(dataset=ds, batch_size=cfg.inference_batch_size) + response = trainer.predict(model, request_dl) + + print("***************************") + print(response) + print("***************************") + + +if __name__ == '__main__': + main() diff --git a/examples/nlp/language_modeling/megatron_retro_qatask_eval.py b/examples/nlp/language_modeling/megatron_retro_qatask_eval.py new file mode 100644 index 0000000000000..b99bcafbab02f --- /dev/null +++ b/examples/nlp/language_modeling/megatron_retro_qatask_eval.py @@ -0,0 +1,217 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import datetime +import json +import os + +import torch +from omegaconf import OmegaConf +from pytorch_lightning.trainer.trainer import Trainer +from torch.utils.data import DataLoader, Dataset + +from nemo.collections.nlp.data.question_answering.input_example.qa_input_example import QAExample +from nemo.collections.nlp.metrics.qa_metrics import QAMetrics +from nemo.collections.nlp.models.language_modeling.megatron_retro_model import MegatronRetroModel +from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel +from nemo.collections.nlp.parts.nlp_overrides import CustomProgressBar, NLPDDPStrategy +from nemo.core.config import hydra_runner +from nemo.utils.app_state import AppState +from nemo.utils.model_utils import inject_model_parallel_rank + +""" +This is the script to run Retro text generation for QA tasks, such as NQ, TQA. + +Usage: + Currently, Mcore-based RETRO only supports a batch size of 1. + Run greedy qa task inference from a distributed checkpoint dir: + python megatron_retro_qatask_eval.py \ + checkpoint_dir=PATH_TO_CHECKPOINT \ + checkpoint_name=CHECKPOINT_NAME \ + inference.greedy=True \ + inference.add_BOS=False \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + inference.retro_inference.retro_num_neighbors=2 \ + qa_file_path=PATH_TO_QAFILE \ + pred_file_path=PATH_TO_PREDFILE + + + ``` +""" + +if not torch.cuda.is_available(): + raise EnvironmentError("GPU is needed for the inference") + + +class RequestDataSet(Dataset): + def __init__(self, sentences, neighbors): + super().__init__() + self.sentences = sentences + self.neighbors = neighbors + + def __len__(self,): + return len(self.sentences) + + def __getitem__(self, idx): + return {'prompts': self.sentences[idx], 'neighbors': self.neighbors[idx]} + + +def process_qasample(sample, retro_num_neighbors=2, ft_neighbours=5): + # process prompt + question = sample['question'] + if not 
question.endswith("?"): + question = question + "?" + processed_prompt = "Question: {} Answer: The answer is".format(question) + + # process neighbors + neighbors = sample['ctxs'] + neighbors = ["title: " + ctx["title"] + ", source: " + ctx["text"] for ctx in neighbors] + processed_neighbors = neighbors[:retro_num_neighbors] + + # # concate neighbors to prompt + if ft_neighbours > 0: + contexts = "\n\n".join(neighbors[:ft_neighbours]) + "\n\n" + processed_prompt = contexts + processed_prompt + + return processed_prompt, processed_neighbors + + +def process_qaresponse(response): + prediction = response.split("The answer is")[1] + # truncate text + prediction = prediction.split(".")[0] + prediction = prediction.split("\n")[0] + prediction = prediction.split("\n\n")[0] + return prediction + + +@hydra_runner(config_path="conf", config_name="megatron_retro_qatask") +def main(cfg) -> None: + + # trainer required for restoring model parallel models + trainer = Trainer( + strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)), + **cfg.trainer, + callbacks=[CustomProgressBar()], + ) + + if cfg.checkpoint_dir: + app_state = AppState() + if cfg.tensor_model_parallel_size > 1 or cfg.pipeline_model_parallel_size > 1: + app_state.model_parallel_size = cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size + app_state.tensor_model_parallel_size = cfg.tensor_model_parallel_size + app_state.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size + ( + app_state.tensor_model_parallel_rank, + app_state.pipeline_model_parallel_rank, + app_state.model_parallel_size, + app_state.data_parallel_size, + app_state.pipeline_model_parallel_split_rank, + app_state.virtual_pipeline_model_parallel_rank, + ) = fake_initialize_model_parallel( + world_size=app_state.model_parallel_size, + rank=trainer.global_rank, + tensor_model_parallel_size_=cfg.tensor_model_parallel_size, + pipeline_model_parallel_size_=cfg.pipeline_model_parallel_size, + 
pipeline_model_parallel_split_rank_=cfg.pipeline_model_parallel_split_rank, + ) + checkpoint_path = os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name) + # checkpoint_path is a dir in case of distributed checkpointing + if not os.path.isdir(checkpoint_path): + # legacy checkpoint needs model parallel rank injection + checkpoint_path = inject_model_parallel_rank(os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name)) + model = MegatronRetroModel.load_from_checkpoint( + checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer + ) + else: + raise ValueError("Requiring distributed checkpoint dir for loading Mcore RETRO.") + + model.freeze() + + # Have to turn off activations_checkpoint_method for inference + try: + model.model.language_model.encoder.activations_checkpoint_method = None + except AttributeError: + pass + + # Reading QA data files + qa_samples = [] + with open(cfg.qa_file_path, 'r', encoding='utf-8') as f: + qa_samples = json.load(f) + + # Processing prompts and neighbors + prompts = [] + neighbors = [] + ground_truths = [] + for sample in qa_samples: + processed_prompt, processed_neighbors = process_qasample( + sample, cfg.inference.retro_inference.retro_num_neighbors, cfg.inference.retro_inference.ft_neighbours + ) + prompts.append(processed_prompt) + neighbors.append(processed_neighbors) + ground_truths.append( + sample['answers'][0] + ) # Boxin only takes the first value of sample['answers'] (https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/boxin/instructretro-internal-test/tools/retro/text_generation/evaluate.py?ref_type=heads#L85) + + # Running prediction + bs = 1 + ds = RequestDataSet(prompts, neighbors) + request_dl = DataLoader(dataset=ds, batch_size=bs) + config = OmegaConf.to_container(cfg.inference) + model.set_inference_config(config) + response = trainer.predict(model, request_dl) + + # Generating answers + print("***************************") + with open(cfg.pred_file_path, "w", encoding="utf-8") as pred_file: + for i in 
range(len(response)): + for sent in response[i]["sentences"]: + sent = sent.strip() + sent = sent.replace("\n", " ") + pred_file.write(sent + "\n") + for neighbor in neighbors[i]: + neighbor = neighbor.replace("\n", " ") + neighbor = "Neighbor: " + neighbor + pred_file.write(neighbor + "\n") + pred_file.write("---------\n") + print(f"Inference Complete, prediction file saved at {cfg.pred_file_path}") + print("***************************") + + # Compute metrics + predictions = [process_qaresponse(response[i]["sentences"][0]) for i in range(len(response))] + formatted_ground_truths = [] + formatted_predictions = [] + for i in range(len(predictions)): # formatting to use NeMo's QAMetrics methods + question_id = i + qaexample = QAExample( + qas_id=question_id, + answers=[{'text': ground_truths[i]}], + question_text="", + context_text="", + context_id="", + answer_text="", + start_position_character="", + title="", + ) + formatted_ground_truths.append(qaexample) + formatted_predictions.append(predictions[i]) + eval_results = QAMetrics.evaluate_predictions(formatted_ground_truths, formatted_predictions) + print("Eval_results: ", eval_results) + + +if __name__ == '__main__': + main() diff --git a/nemo/collections/nlp/models/language_modeling/megatron_retro_model.py b/nemo/collections/nlp/models/language_modeling/megatron_retro_model.py index 8cc39056554ce..377ccbee163bf 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_retro_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_retro_model.py @@ -219,10 +219,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] inference_config = self.get_inference_config() - if torch.distributed.get_rank() == 0: - logging.info("inference_config: ") - logging.info(inference_config) - if inference_config is None: return None else: @@ -359,9 +355,9 @@ def get_forward_output_only_func(self): def fwd_output_only_func(dataloader_iter, model): batch = next(dataloader_iter) 
extra_arg = {} - if len(batch) == 5: + if len(batch) == 6: batch = [x.cuda() for x in batch] - tokens, attention_mask, position_ids, context_input_ids, context_position_ids, context_mask = batch + tokens, attention_mask, position_ids, context_input_ids, context_mask, context_position_ids = batch attention_mask = attention_mask[0:1] else: ( @@ -369,26 +365,21 @@ def fwd_output_only_func(dataloader_iter, model): attention_mask, position_ids, context_input_ids, - context_position_ids, context_mask, + context_position_ids, set_inference_key_value_memory, inference_max_sequence_len, ) = batch + # Transfer needed data to GPU tokens = tokens.cuda() position_ids = position_ids.cuda() - if attention_mask is not None: - attention_mask = attention_mask.cuda() - attention_mask = attention_mask[0:1] context_input_ids = context_input_ids.cuda() context_position_ids = context_position_ids.cuda() context_mask = None if self.mcore_gpt: - # if first step, then clear KV cache, otherwise reuse inference_paarms - if set_inference_key_value_memory[0].item(): - self.inference_params = InferenceParams( - max_batch_size=tokens.size(0), max_sequence_length=inference_max_sequence_len[0].item() - ) - extra_arg['inference_params'] = self.inference_params + # No caching key, value because currently it's not supported for mcore RETRO in NeMo + pass + else: extra_arg['set_inference_key_value_memory'] = set_inference_key_value_memory[0].item() extra_arg['inference_max_sequence_len'] = inference_max_sequence_len[0].item() diff --git a/nemo/collections/nlp/modules/common/text_generation_strategy.py b/nemo/collections/nlp/modules/common/text_generation_strategy.py index 3abfda2a5e446..e29bb3423c4a5 100644 --- a/nemo/collections/nlp/modules/common/text_generation_strategy.py +++ b/nemo/collections/nlp/modules/common/text_generation_strategy.py @@ -22,6 +22,7 @@ import torch from nemo.collections.nlp.modules.common.lm_utils import pad_batch +from nemo.collections.nlp.modules.common.megatron.module 
class McoreRetroModelTextGenerationStrategy(TextGenerationStrategy):
    """Text generation strategy for the Megatron-core (Mcore) RETRO model.

    Besides tokenizing the prompt, this strategy tokenizes the retrieved neighbors, lays the prompt
    out as two chunks (prompt followed by an equally long EOS-filled chunk) and patches the RETRO
    configuration and cross-attention layers of the wrapped model with the sizes of the current sample.
    Key/value caching is not used: every step is run over the entire context.
    """

    def __init__(self, model):
        super().__init__(model)
        # Network whose config and decoder layers are patched in init_batch().
        self.forward_model = self.model.model

    def clip_max_len(self, maxlen: int) -> int:
        """Clip ``maxlen`` to ``encoder_seq_length + 1`` when learned absolute position embeddings are used."""
        cfg = self.model.cfg
        # Other positional embedding types allow length extrapolation, so the length is left untouched.
        if cfg.get("position_embedding_type", "learned_absolute") != "learned_absolute":
            return maxlen
        limit = cfg.encoder_seq_length + 1
        return limit if maxlen > limit else maxlen

    def tokenize_batch(self, sentences, max_len, add_BOS):
        """
        Tokenize the input sentences and pack them into CUDA tensors.

        Args:
            sentences (List[str]): list of input sentences in str format.
            max_len (int): not read by this implementation; ``pad_batch`` is called with ``max_len=0``.
            add_BOS (bool): whether to add the BOS token at the beginning of every sentence.
        Returns:
            Tuple[torch.Tensor], the token tensor and the context length tensor returned by ``pad_batch``.
        """
        tokenizer = self.model.tokenizer
        if add_BOS:
            token_lists = [[tokenizer.bos_id] + tokenizer.text_to_ids(text) for text in sentences]
        else:
            token_lists = [tokenizer.text_to_ids(text) for text in sentences]

        # pad_batch is called with max_len=0 here; the chunk layout itself is built later in init_batch
        padded_tokens, lengths = pad_batch(batch=token_lists, pad_id=tokenizer.eos_id, max_len=0)

        return torch.cuda.LongTensor(padded_tokens), torch.cuda.LongTensor(lengths)

    def tokenize_neighbors_batch(self, neighbors, retro_args):
        """
        Tokenize the retrieved neighbors and bring every neighbor to a fixed length.

        Args:
            neighbors: nested sequence of neighbor strings; each outer entry is processed independently.
            retro_args: mapping providing ``retro_gpt_retrieved_length``, ``retro_num_neighbors``,
                ``ft_neighbours`` and ``reuse_top``.
        Returns:
            Tuple[torch.Tensor], the neighbor token tensor and a tensor holding its shape.
        """
        tokenizer = self.model.tokenizer
        retrieved_length = retro_args['retro_gpt_retrieved_length']
        num_neighbors = retro_args['retro_num_neighbors']
        ft_neighbours = retro_args['ft_neighbours']
        reuse_top = retro_args['reuse_top']

        batch_tokens = []
        for sample_neighbors in neighbors:
            sample_tokens = [tokenizer.text_to_ids(text) for text in sample_neighbors]

            # keep the top neighbors, or a window of the same size shifted by ft_neighbours
            if reuse_top:
                kept = sample_tokens[:num_neighbors]
            else:
                kept = sample_tokens[ft_neighbours : num_neighbors + ft_neighbours]

            # truncate long neighbors, pad short ones with EOS
            fixed_length = []
            for ids in kept:
                if len(ids) >= retrieved_length:
                    fixed_length.append(ids[:retrieved_length])
                else:
                    fixed_length.append(ids + [tokenizer.eos_id] * (retrieved_length - len(ids)))

            if len(fixed_length) < num_neighbors:
                # NOTE(review): asserting an exception instance is always truthy, so this check never
                # fails. It is kept as a no-op on purpose: init_batch documents its input as [k, 1, r]
                # (one neighbor per outer entry), so turning this into `raise` needs the caller-side
                # layout to be confirmed first.
                assert ValueError("neighbours are not enough, add empty ones and create mask for those empty ones")

            batch_tokens.append(fixed_length)

        neighbors_tensor = torch.cuda.LongTensor(batch_tokens)
        neighbors_shape = torch.cuda.LongTensor(neighbors_tensor.shape)
        return neighbors_tensor, neighbors_shape

    def init_batch(self, context_tokens: torch.Tensor, context_length: int, compute_attention_mask: bool, **extra):
        """
        Initialize the batch data before the inference steps.

        The prompt is extended with an EOS-filled chunk of the same length, masks and position ids are
        computed for prompt and neighbors, and the RETRO sizes stored in the model config and in its
        cross-attention layers are overwritten with the sizes of this sample.

        Args:
            context_tokens (torch.Tensor): 2D prompt tokens; only batch size 1 is supported.
            context_length (int): not read by this implementation.
            compute_attention_mask (bool): forwarded to ``get_ltor_masks_and_position_ids`` for the prompt.
            **extra: must contain ``neighbors_tokens``, a 3D tensor
                (per the original author's note laid out as [k, 1, r] -- TODO confirm against the caller).
        Returns:
            torch.Tensor, the prompt tokens extended to twice their original length.
        """
        tokenizer = self.model.tokenizer

        # Two chunks of equal length: one for the question and one (EOS-filled) for the answer.
        batch_size, prompt_length = context_tokens.shape
        assert batch_size == 1  # similar to M-LM RETRO inference code, currently only support batch_size=1
        two_chunks = context_tokens[0].tolist() + [tokenizer.eos_id] * prompt_length
        context_tokens = torch.cuda.LongTensor([two_chunks])
        self.model.model.config.retro_gpt_chunk_length = prompt_length  # set RetroConfig of M-LM's RETRO model

        # Swap the first two axes, add a leading axis, then duplicate along axis 1 (one copy per chunk).
        neighbors_tokens = extra['neighbors_tokens'].permute(1, 0, 2).unsqueeze(0).repeat(1, 2, 1, 1)

        # Move to GPU.
        tokens = context_tokens.contiguous().cuda()
        neighbors_tokens = neighbors_tokens.contiguous().cuda()

        cfg = self.model.cfg
        reset_position_ids = cfg.get('reset_position_ids', False)
        reset_attention_mask = cfg.get('reset_attention_mask', False)
        eod_mask_loss = cfg.get('eod_mask_loss', False)

        # Attention mask and position ids of the prompt.
        self.attention_mask, _, self.position_ids = get_ltor_masks_and_position_ids(
            tokens,
            tokenizer.eos_id,
            reset_position_ids,
            reset_attention_mask,
            eod_mask_loss,
            compute_attention_mask=compute_attention_mask,
        )

        # Position ids of the neighbors; the 4D neighbor tensor is flattened to 2D ([-1, r]) first,
        # which is also the layout handed to the model in prepare_batch_at_step.
        _, _, num_neighbors, retrieved_length = neighbors_tokens.shape
        neighbors_tokens = neighbors_tokens.view(-1, retrieved_length).long()
        _, _, self.neighbor_position_ids = get_ltor_masks_and_position_ids(
            neighbors_tokens, tokenizer.eos_id, reset_position_ids, reset_attention_mask, eod_mask_loss,
        )
        # Dummy value: per the original note, the neighbor attention mask of the batch is set to None
        # in megatron_retro_model.py for the Mcore implementation.
        self.neighbor_attention_mask = torch.zeros([1, 1])
        self.neighbors_tokens = neighbors_tokens

        # Following ADLR's Mcore RETRO inference implementation, overwrite the sizes stored inside the
        # RETRO model with the ones of the current sample.
        model_config = self.forward_model.config
        model_config.retro_num_neighbors = num_neighbors
        model_config.retro_chunk_length = prompt_length
        model_config.retro_retrieved_length = retrieved_length

        if isinstance(self.forward_model, (Float16Module, MCoreFloat16Module)):
            layers = self.forward_model.module.decoder.layers
        else:
            layers = self.forward_model.decoder.layers

        encoder_pending = True  # only the first cross-attention decoder layer contains the encoder
        for layer in layers:
            cross_attention = layer.cross_attention
            if isinstance(cross_attention, IdentityOp):
                continue  # not an encoder-decoder cross-attention layer

            # RetroDecoder (RetroDecoderCrossAttention, RetroDecoderBiasDropoutAdd)
            cross_attention.retro_num_neighbors = num_neighbors
            cross_attention.retro_chunk_length = prompt_length
            cross_attention.retro_retrieved_length = retrieved_length
            layer.cross_attn_bda.retro_chunk_length = prompt_length

            # RetroEncoder (RetroEncoderCrossAttention, RetroEncoderBiasDropoutAdd, RetroEncoderLayerNorm)
            if encoder_pending:
                encoder_layer = cross_attention.encoder.layers[0]
                encoder_layer.cross_attention.retro_num_neighbors = num_neighbors
                encoder_layer.cross_attention.retro_chunk_length = prompt_length
                encoder_layer.cross_attention.retro_retrieved_length = retrieved_length
                encoder_layer.cross_attn_bda.retro_num_neighbors = num_neighbors
                encoder_layer.pre_mlp_layernorm.retro_num_neighbors = num_neighbors
                encoder_pending = False

        return context_tokens

    def prepare_batch_at_step(
        self,
        tokens: torch.Tensor,
        maxlen: int,
        micro_batch_size: int,
        step: int,
        context_length: int,
        compute_attention_mask: bool = True,
        **extra,
    ) -> Tuple[List[torch.Tensor], List[int]]:
        """
        Generate the batch used in inference for each of the steps.

        Memory caching is currently not supported for the Mcore RETRO model, so the full ``tokens``
        tensor and the position ids / attention mask stored by ``init_batch`` are used at every step;
        ``step`` and ``context_length`` are not read.

        Returns:
            The batch list (tokens, attention mask or None, position ids, neighbor tokens, neighbor
            attention mask, neighbor position ids, key/value-memory flags, max lengths) and the tensor
            shape ``[tokens.shape[1], micro_batch_size, hidden_size]``.
        """
        # Always allocate memory for the entire context.
        set_inference_key_value_memory = True

        # Prepare batch for each of the inference steps
        if compute_attention_mask:
            attention_mask_repeat = torch.concat([self.attention_mask] * micro_batch_size)
        else:
            attention_mask_repeat = None

        device = torch.cuda.current_device()
        setkey_value_array = torch.tensor([set_inference_key_value_memory] * micro_batch_size, device=device)
        len_array = torch.tensor([maxlen] * micro_batch_size, device=device)

        batch = [
            tokens,
            attention_mask_repeat,
            self.position_ids,
            self.neighbors_tokens,
            self.neighbor_attention_mask,
            self.neighbor_position_ids,
            setkey_value_array,
            len_array,
        ]
        tensor_shape = [tokens.shape[1], micro_batch_size, self.model.cfg.hidden_size]
        return batch, tensor_shape
GPTModelTextGenerationStrategy(model) elif isinstance(model, MegatronRetrievalModel): strategy_name = args['strategy'] @@ -625,6 +848,8 @@ def model_inference_strategy_dispatcher(model, **args): return RetroFileQAModelTextGenerationStrategy(model, **args) else: raise ValueError(f'{strategy_name} is not supported for inference') + elif isinstance(model, MegatronRetroModel): + return McoreRetroModelTextGenerationStrategy(model) else: raise ValueError(f'{model} is not supported for inference') diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py index 3daf93ac0ed2a..d130322404b6d 100644 --- a/nemo/collections/nlp/modules/common/text_generation_utils.py +++ b/nemo/collections/nlp/modules/common/text_generation_utils.py @@ -473,6 +473,7 @@ def synced_generate( end_strings=[], min_tokens_to_generate=0, image_list=None, + **strategy_args, ): context_length = context_length_tensor.min().item() tokenizer = model.tokenizer @@ -488,6 +489,19 @@ def synced_generate( temperature=temperature, ) else: + + extra = { + "top_p": top_p, + "top_k": top_k, + "greedy": greedy, + "repetition_penalty": repetition_penalty, + "min_tokens_to_generate": min_tokens_to_generate, + } + + # if input containing neighbors (for Mcore retrieval RETRO model) + if "neighbors_tokens" in strategy_args: + extra['neighbors_tokens'] = strategy_args['neighbors_tokens'] + batch_token_iterator = sample_sequence_batch( model, inference_strategy, @@ -500,13 +514,7 @@ def synced_generate( temperature=temperature, end_strings=end_strings, image_list=image_list, - extra={ - "top_p": top_p, - "top_k": top_k, - "greedy": greedy, - "repetition_penalty": repetition_penalty, - "min_tokens_to_generate": min_tokens_to_generate, - }, + extra=extra, ) for tokens, lengths, output_logits, full_logits in batch_token_iterator: @@ -626,6 +634,22 @@ def generate( end_strings, random_seed, ) + + # tokenize neighbors and broadcast (for Mcore 
retrieval RETRO model) + if 'neighbors' in strategy_args: + # tokenize neighbors + neighbors_tokens_tensor, neighbors_tokens_tensor_shape = inference_strategy.tokenize_neighbors_batch( + strategy_args['neighbors'], strategy_args['retro_inference'] + ) + + # send neighbors tensors to all ranks + model_parallel_group = parallel_state.get_model_parallel_group() + src = get_model_parallel_src_rank() + torch.distributed.broadcast(neighbors_tokens_tensor_shape, src, model_parallel_group) + torch.distributed.broadcast(neighbors_tokens_tensor, src, model_parallel_group) + else: + neighbors_tokens_tensor = None + else: ( context_length_tensor, @@ -643,6 +667,27 @@ def generate( random_seed, ) = receive_generate_info() + # receive broadcast (for Mcore retrieval RETRO model) + if 'neighbors' in strategy_args: + # receive neighbors tensors to all ranks + model_parallel_group = parallel_state.get_model_parallel_group() + src = get_model_parallel_src_rank() + neighbors_tokens_tensor_shape = torch.empty(2, dtype=torch.float32, device=torch.cuda.current_device()) + torch.distributed.broadcast(neighbors_tokens_tensor_shape, src, model_parallel_group) + neighbors_tokens_tensor = torch.empty( + neighbors_tokens_tensor_shape[0], + neighbors_tokens_tensor_shape[1], + dtype=torch.int64, + device=torch.cuda.current_device(), + ) + torch.distributed.broadcast(neighbors_tokens_tensor, src, model_parallel_group) + else: + neighbors_tokens_tensor = None + + # add neighbors to strategy_args (for retrieval RETRO model) + if 'neighbors' in strategy_args: + strategy_args['neighbors_tokens'] = neighbors_tokens_tensor + if random_seed is not None: seed_everything(random_seed) @@ -663,6 +708,7 @@ def generate( end_strings=end_strings, min_tokens_to_generate=min_tokens_to_generate, image_list=image_list, + **strategy_args, ) special_tokens = set() if hasattr(tokenizer, 'pad_token') and tokenizer.pad_token is not None: @@ -771,7 +817,15 @@ def sample_sequence_batch( # initialize the batch with 
torch.no_grad(): context_length = context_lengths.min().item() - inference_strategy.init_batch(context_tokens, context_length, compute_attention_mask) + if 'neighbors_tokens' in extra: # for Mcore retrieval RETRO model + + # For Mcore retrieval RETRO model, context_tokens tensors are updated after init_batch() (the length is doubled after processing) + context_tokens = inference_strategy.init_batch( + context_tokens, context_length, compute_attention_mask, **extra + ) + + else: + inference_strategy.init_batch(context_tokens, context_length, compute_attention_mask) # added eos_id to support the function generate_samples_eval that passes # eos_id as an argument and needs termination when that id id found. eod_id = tokenizer.eos_id @@ -809,7 +863,11 @@ def sample_sequence_batch( logits = output[:, -1].view(batch_size, -1).contiguous() else: - logits = output[0]['logits'][:, -1].contiguous() + if 'neighbors_tokens' in extra: # for Mcore retrieval RETRO model + # for Mcore RETRO inference, disimilar to GPT, we will get the logits of the (context_length - 1)th token, instead of the last token + logits = output[0]['logits'][:, context_length - 1].contiguous() + else: + logits = output[0]['logits'][:, -1].contiguous() logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) assert logits is not None logits = logits.view(batch_size, -1)