From a4c14489c748a6cea9e173e422d1060486e3d05e Mon Sep 17 00:00:00 2001 From: Aman Jaiswal Date: Sun, 21 Jul 2024 17:48:05 -0300 Subject: [PATCH] enhancement: add_raw_input (#698) --- src/distilabel/steps/tasks/base.py | 35 +++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index 4479954fa5..f82f86d8f8 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -47,6 +47,8 @@ class _Task(_Step, ABC): add_raw_output: whether to include a field with the raw output of the LLM in the `distilabel_metadata` field of the output. Can be helpful to not loose data with `Tasks` that need to format the output of the `LLM`. Defaults to `False`. + add_raw_input: whether to include a field with the raw input to the LLM in the + `distilabel_metadata` field of the output. num_generations: The number of generations to be produced per input. """ @@ -59,7 +61,14 @@ class _Task(_Step, ABC): "Whether to include the raw output of the LLM in the key `raw_output_`" " of the `distilabel_metadata` dictionary output column" ), - ) + ), + add_raw_input: RuntimeParameter[bool] = Field( + default=False, + description=( + "Whether to include the raw input to the LLM in the key `raw_input_`" + " of the `distilabel_metadata` dictionary output column" + ), + ), num_generations: RuntimeParameter[int] = Field( default=1, description="The number of generations to be produced per input." ) @@ -115,6 +124,11 @@ def _format_outputs( output, add_raw_output=self.add_raw_output, # type: ignore ) + formatted_output = self._maybe_add_raw_input( + formatted_output, + input, + add_raw_input=self._maybe_add_raw_input, # type: ignore + ) formatted_outputs.append(formatted_output) except Exception as e: self._logger.warning( # type: ignore @@ -131,12 +145,18 @@ def _output_on_failure( """ # Create a dictionary with the outputs of the task (every output set to None) outputs = {output: None for output in self.outputs} + inputs = {input: None for input in self.inputs} outputs["model_name"] = self.llm.model_name # type: ignore outputs = self._maybe_add_raw_output( outputs, output, add_raw_output=self.add_raw_output, # type: ignore ) + outputs = self._maybe_add_raw_input( + outputs, + inputs, + add_raw_input=self.add_raw_input, # type: ignore + ) return outputs def _maybe_add_raw_output( @@ -152,6 +172,19 @@ def _maybe_add_raw_output( output[DISTILABEL_METADATA_KEY] = meta return output + def _maybe_add_raw_input( + self, + input: Dict[str, Any], + raw_input: Union[str, None], + add_raw_input: bool = False, + ) -> Dict[str, Any]: + """Adds the raw input to the LLM to the output dictionary if `add_raw_input` is True.""" + if add_raw_input: + meta = input.get(DISTILABEL_METADATA_KEY, {}) + meta[f"raw_input_{self.name}"] = raw_input + input[DISTILABEL_METADATA_KEY] = meta + return input + class Task(_Task, Step): """Task is a class that implements the `_Task` abstract class and adds the `Step`