Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhancement: add_raw_input (#698) #799

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion src/distilabel/steps/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class _Task(_Step, ABC):
add_raw_output: whether to include a field with the raw output of the LLM in the
`distilabel_metadata` field of the output. Can be helpful to not lose data
with `Tasks` that need to format the output of the `LLM`. Defaults to `False`.
add_raw_input: whether to include a field with the raw input to the LLM in the
`distilabel_metadata` field of the output. Defaults to `False`.
num_generations: The number of generations to be produced per input.
"""

Expand All @@ -59,7 +61,14 @@ class _Task(_Step, ABC):
"Whether to include the raw output of the LLM in the key `raw_output_<TASK_NAME>`"
" of the `distilabel_metadata` dictionary output column"
),
)
),
# Runtime flag controlling whether the raw LLM input is stored in the output
# metadata. NOTE: there must be no trailing comma after `Field(...)`; with one,
# the assigned value becomes a one-element tuple instead of the `Field` itself.
add_raw_input: RuntimeParameter[bool] = Field(
    default=False,
    description=(
        "Whether to include the raw input to the LLM in the key `raw_input_<TASK_NAME>`"
        " of the `distilabel_metadata` dictionary output column"
    ),
)
num_generations: RuntimeParameter[int] = Field(
default=1, description="The number of generations to be produced per input."
)
Expand Down Expand Up @@ -115,6 +124,11 @@ def _format_outputs(
output,
add_raw_output=self.add_raw_output, # type: ignore
)
# Gate on the `add_raw_input` runtime parameter. Passing the bound method
# `self._maybe_add_raw_input` as the flag would always evaluate as truthy and
# add the raw input regardless of the user's setting.
formatted_output = self._maybe_add_raw_input(
    formatted_output,
    input,
    add_raw_input=self.add_raw_input,  # type: ignore
)
formatted_outputs.append(formatted_output)
except Exception as e:
self._logger.warning( # type: ignore
Expand All @@ -131,12 +145,18 @@ def _output_on_failure(
"""
# Create a dictionary with the outputs of the task (every output set to None)
outputs = {output: None for output in self.outputs}
inputs = {input: None for input in self.inputs}
outputs["model_name"] = self.llm.model_name # type: ignore
outputs = self._maybe_add_raw_output(
outputs,
output,
add_raw_output=self.add_raw_output, # type: ignore
)
outputs = self._maybe_add_raw_input(
outputs,
inputs,
add_raw_input=self.add_raw_input, # type: ignore
)
return outputs

def _maybe_add_raw_output(
Expand All @@ -152,6 +172,19 @@ def _maybe_add_raw_output(
output[DISTILABEL_METADATA_KEY] = meta
return output

def _maybe_add_raw_input(
    self,
    input: Dict[str, Any],
    raw_input: Union[Dict[str, Any], str, None],
    add_raw_input: bool = False,
) -> Dict[str, Any]:
    """Store the raw LLM input in the metadata of a dictionary if requested.

    Args:
        input: the dictionary that receives the metadata entry. It is updated
            in place. The parameter name is kept as-is so existing keyword
            callers keep working.
        raw_input: the value stored under the `raw_input_<TASK_NAME>` key.
            A caller in this file passes a dictionary, so dictionaries are
            accepted in addition to `str` and `None`.
        add_raw_input: when `False` the dictionary is returned untouched.

    Returns:
        The same dictionary object that was passed as `input`.
    """
    if not add_raw_input:
        return input

    meta = input.get(DISTILABEL_METADATA_KEY)
    if meta is None:
        # The key is either missing or explicitly set to `None`; in both
        # cases start from an empty metadata dictionary instead of failing
        # on item assignment.
        meta = {}
    meta[f"raw_input_{self.name}"] = raw_input
    input[DISTILABEL_METADATA_KEY] = meta
    return input


class Task(_Task, Step):
"""Task is a class that implements the `_Task` abstract class and adds the `Step`
Expand Down