Include ID in text output of LM generate (#668)
cbalioglu authored Jul 13, 2024
1 parent 53ef569 commit a662709
Showing 2 changed files with 16 additions and 11 deletions.
src/fairseq2/datasets/instruction.py (11 changes: 7 additions & 4 deletions)
@@ -294,7 +294,7 @@ def create_reader(
         builder.map(target_encoder, selector="tgt", num_parallel_calls=npc)

         def cat_source_and_target(example: Dict[str, Any]) -> Dict[str, Any]:
-            sample_id = example.get("id", None)
+            id_ = example.get("id")

             prompt_indices = example["src"]
             target_indices = example["tgt"]
@@ -303,7 +303,7 @@ def cat_source_and_target(example: Dict[str, Any]) -> Dict[str, Any]:

             target_mask = torch.arange(len(indices)) >= len(prompt_indices)

-            return {"indices": indices, "target_mask": target_mask, "id": sample_id}
+            return {"id": id_, "indices": indices, "target_mask": target_mask}

         builder.map(cat_source_and_target, num_parallel_calls=npc)
@@ -398,10 +398,13 @@ def create_prompt_reader(
         text_encoder = tokenizer.create_encoder(mode="prompt")

         def encode(example: Dict[str, Any]) -> Dict[str, Any]:
-            sample_id = example.get("id", None)
+            id_ = example.get("id")

             prompt = example["src"]

-            return {"prompt": prompt, "indices": text_encoder(prompt), "id": sample_id}
+            indices = text_encoder(prompt)
+
+            return {"id": id_, "prompt": prompt, "indices": indices}

         builder.map(encode, num_parallel_calls=npc)
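For context, a minimal standalone sketch of what the updated cat_source_and_target now produces for one example. The example dict is a made-up stand-in, and the torch.cat join is an assumption, since that line sits in the elided context between the two hunks:

import torch

# Made-up example dict in the shape the instruction pipeline feeds to
# cat_source_and_target; real examples come from the dataset builder.
example = {
    "id": "sample-42",
    "src": torch.tensor([5, 6, 7]),  # encoded prompt indices
    "tgt": torch.tensor([8, 9]),     # encoded target indices
}

id_ = example.get("id")  # None when the dataset has no "id" field

prompt_indices = example["src"]
target_indices = example["tgt"]

indices = torch.cat([prompt_indices, target_indices])  # assumed join step

# True at positions that belong to the target, False for the prompt.
target_mask = torch.arange(len(indices)) >= len(prompt_indices)

print({"id": id_, "indices": indices, "target_mask": target_mask})
# {'id': 'sample-42', 'indices': tensor([5, 6, 7, 8, 9]),
#  'target_mask': tensor([False, False, False,  True,  True])}

Note that example.get("id") is equivalent to the old example.get("id", None), since dict.get already defaults to None; dropping the explicit default is cosmetic, as is the rename from sample_id to id_.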
src/fairseq2/recipes/lm/text_generate.py (16 changes: 9 additions & 7 deletions)
@@ -403,7 +403,7 @@ def __call__(self, batch: SequenceBatch) -> None:
         except KeyError:
             raise ValueError("`batch.example` must contain a 'prompt' item.")

-        sample_ids = batch.example["id"]
+        ids = batch.example["id"]

         output = self._generator(batch.seqs, batch.padding_mask)
@@ -413,9 +413,7 @@ def __call__(self, batch: SequenceBatch) -> None:
         if not self._text_output_stream and not self._json_output_stream:
             return

-        for sample_id, prompt, hypotheses in zip(
-            sample_ids, prompts, output.hypotheses
-        ):
+        for id_, prompt, hypotheses in zip(ids, prompts, output.hypotheses):
             if len(hypotheses) == 0:
                 raise RuntimeError(
                     "The sequence generator returned no hypothesis. Please file a bug report."
@@ -441,6 +439,12 @@ def __call__(self, batch: SequenceBatch) -> None:

             # Dump as text.
             if stream := self._text_output_stream:
+                if id_ is not None:
+                    stream.write("<<<<< ID >>>>>")
+                    stream.write("\n")
+                    stream.write(f"{id_}")
+                    stream.write("\n\n")
+
                 stream.write("<<<<< PROMPT >>>>>")
                 stream.write("\n")
                 stream.write(prompt)
@@ -459,22 +463,20 @@ def __call__(self, batch: SequenceBatch) -> None:
                     stream.write("\n\n")
                     stream.write("<<<<< SCORE >>>>>")
                     stream.write("\n")
-
                     stream.write(f"{score:.8f}")

                 if step_scores is not None:
                     stream.write("\n\n")
                     stream.write("<<<<< STEP SCORES >>>>>")
                     stream.write("\n")
-
                     stream.write(", ".join(f"{s:.8f}" for s in step_scores))

                 stream.write("\n\n\n============================\n\n\n")

             # Dump as JSON.
             if stream := self._json_output_stream:
                 json_output = {
-                    "id": sample_id,
+                    "id": id_,
                     "prompt": prompt,
                     "response": response,
                     "token_indices": token_indices,
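The effect on the text dump is easiest to see in isolation. Below is a minimal sketch of just the new ID block, mimicking the writes added in this commit; io.StringIO and the sample values are stand-ins, and when a dataset provides no IDs (id_ is None) the block is skipped entirely:

import io

# Stand-in for self._text_output_stream; any text file object works.
stream = io.StringIO()

id_ = "sample-42"  # None when the dataset provides no "id" field

# The writes added by this commit: an ID block ahead of the prompt.
if id_ is not None:
    stream.write("<<<<< ID >>>>>")
    stream.write("\n")
    stream.write(f"{id_}")
    stream.write("\n\n")

stream.write("<<<<< PROMPT >>>>>")
stream.write("\n")

print(stream.getvalue())
# <<<<< ID >>>>>
# sample-42
#
# <<<<< PROMPT >>>>>

The JSON output needs no new key: the record already carried an "id" field, which now simply reads the renamed id_ variable.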
