Skip to content

Commit

Permalink
Return instruction and response if n_turns==1
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmbmb committed Jul 26, 2024
1 parent 90909ab commit 61d8de1
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
19 changes: 16 additions & 3 deletions src/distilabel/steps/tasks/magpie/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,19 +132,30 @@ def _generate_instruction(
def _prepare_conversation_outputs(
self, conversations: List["ChatType"]
) -> List[Dict[str, Any]]:
"""Prepare the output conversation removing the system prompt if necessary.
"""Prepare the output conversation removing the system prompt if necessary. If
`n_turns==1`, then it will return a dictionary with "instruction" and "response"
keys. Otherwise, it will return a dictionary with a "conversation" key.
Args:
conversations: the list of generated conversations.
Returns:
A list of dictionaries containing a "conversation" key.
A list of dictionaries containing a "conversation" key or "instruction" and
"responses" key.
"""
outputs = []
for conversation in conversations:
if not self.include_system_prompt and conversation[0]["role"] == "system":
conversation.pop(0)
outputs.append({"conversation": conversation})
if self.n_turns == 1 and len(conversation) == 2:
outputs.append(
{
"instruction": conversation[0]["content"],
"response": conversation[1]["content"],
}
)
else:
outputs.append({"conversation": conversation})
return outputs

def _generate_conversation_turn(
Expand Down Expand Up @@ -425,6 +436,8 @@ def outputs(self) -> List[str]:
"""Either a multi-turn conversation or the instruction generated."""
if self.only_instruction:
return ["instruction", "model_name"]
if self.n_turns == 1:
return ["instruction", "response", "model_name"]
return ["conversation", "model_name"]

def format_output(
Expand Down
2 changes: 2 additions & 0 deletions src/distilabel/steps/tasks/magpie/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ def outputs(self) -> List[str]:
"""Either a multi-turn conversation or the instruction generated."""
if self.only_instruction:
return ["instruction", "model_name"]
if self.n_turns == 1:
return ["instruction", "response", "model_name"]
return ["conversation", "model_name"]

def process(self, offset: int = 0) -> "GeneratorStepOutput":
Expand Down

0 comments on commit 61d8de1

Please sign in to comment.