From a662709581c3641f842e79fbebca071287ef6b45 Mon Sep 17 00:00:00 2001
From: Can Balioglu
Date: Sat, 13 Jul 2024 13:43:13 -0400
Subject: [PATCH] Include ID in text output of LM generate (#668)

---
 src/fairseq2/datasets/instruction.py     | 11 +++++++----
 src/fairseq2/recipes/lm/text_generate.py | 16 +++++++++-------
 2 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/src/fairseq2/datasets/instruction.py b/src/fairseq2/datasets/instruction.py
index b5ee1865e..550ec3622 100644
--- a/src/fairseq2/datasets/instruction.py
+++ b/src/fairseq2/datasets/instruction.py
@@ -294,7 +294,7 @@ def create_reader(
         builder.map(target_encoder, selector="tgt", num_parallel_calls=npc)
 
         def cat_source_and_target(example: Dict[str, Any]) -> Dict[str, Any]:
-            sample_id = example.get("id", None)
+            id_ = example.get("id")
 
             prompt_indices = example["src"]
             target_indices = example["tgt"]
@@ -303,7 +303,7 @@ def cat_source_and_target(example: Dict[str, Any]) -> Dict[str, Any]:
 
             target_mask = torch.arange(len(indices)) >= len(prompt_indices)
 
-            return {"indices": indices, "target_mask": target_mask, "id": sample_id}
+            return {"id": id_, "indices": indices, "target_mask": target_mask}
 
         builder.map(cat_source_and_target, num_parallel_calls=npc)
 
@@ -398,10 +398,13 @@ def create_prompt_reader(
         text_encoder = tokenizer.create_encoder(mode="prompt")
 
         def encode(example: Dict[str, Any]) -> Dict[str, Any]:
-            sample_id = example.get("id", None)
+            id_ = example.get("id")
+
             prompt = example["src"]
 
-            return {"prompt": prompt, "indices": text_encoder(prompt), "id": sample_id}
+            indices = text_encoder(prompt)
+
+            return {"id": id_, "prompt": prompt, "indices": indices}
 
         builder.map(encode, num_parallel_calls=npc)
 
diff --git a/src/fairseq2/recipes/lm/text_generate.py b/src/fairseq2/recipes/lm/text_generate.py
index f2b470c86..5d0c7f46f 100644
--- a/src/fairseq2/recipes/lm/text_generate.py
+++ b/src/fairseq2/recipes/lm/text_generate.py
@@ -403,7 +403,7 @@ def __call__(self, batch: SequenceBatch) -> None:
         except KeyError:
             raise ValueError("`batch.example` must contain a 'prompt' item.")
 
-        sample_ids = batch.example["id"]
+        ids = batch.example["id"]
 
         output = self._generator(batch.seqs, batch.padding_mask)
 
@@ -413,9 +413,7 @@ def __call__(self, batch: SequenceBatch) -> None:
         if not self._text_output_stream and not self._json_output_stream:
             return
 
-        for sample_id, prompt, hypotheses in zip(
-            sample_ids, prompts, output.hypotheses
-        ):
+        for id_, prompt, hypotheses in zip(ids, prompts, output.hypotheses):
             if len(hypotheses) == 0:
                 raise RuntimeError(
                     "The sequence generator returned no hypothesis. Please file a bug report."
@@ -441,6 +439,12 @@ def __call__(self, batch: SequenceBatch) -> None:
 
             # Dump as text.
             if stream := self._text_output_stream:
+                if id_ is not None:
+                    stream.write("<<<<< ID >>>>>")
+                    stream.write("\n")
+                    stream.write(f"{id_}")
+                    stream.write("\n\n")
+
                 stream.write("<<<<< PROMPT >>>>>")
                 stream.write("\n")
                 stream.write(prompt)
@@ -459,14 +463,12 @@ def __call__(self, batch: SequenceBatch) -> None:
                 stream.write("\n\n")
                 stream.write("<<<<< SCORE >>>>>")
                 stream.write("\n")
-
                 stream.write(f"{score:.8f}")
 
                 if step_scores is not None:
                     stream.write("\n\n")
                     stream.write("<<<<< STEP SCORES >>>>>")
                     stream.write("\n")
-
                     stream.write(", ".join(f"{s:.8f}" for s in step_scores))
 
                 stream.write("\n\n\n============================\n\n\n")
@@ -474,7 +476,7 @@ def __call__(self, batch: SequenceBatch) -> None:
             # Dump as JSON.
             if stream := self._json_output_stream:
                 json_output = {
-                    "id": sample_id,
+                    "id": id_,
                     "prompt": prompt,
                     "response": response,
                     "token_indices": token_indices,
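
For illustration only: assuming a hypothetical sample with id 42 and made-up prompt,
score, and step-score values, a record in the text output stream would now begin with
the new ID block and look roughly like the sketch below. The ID block is skipped when
the example carries no "id" field, and the response section (written by code not touched
in this patch) is elided as "...":

    <<<<< ID >>>>>
    42

    <<<<< PROMPT >>>>>
    What is the capital of France?

    ...

    <<<<< SCORE >>>>>
    -1.23456789

    <<<<< STEP SCORES >>>>>
    -0.12345678, -0.23456789


    ============================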