Add `end_with_user` and `include_system_prompt` flags to Magpie tasks and handle `None`s. (#784)

* Add `end_with_user` flag

* Add `include_system_prompt` attribute to `Magpie`

* Update docstrings

* Update `MagpieBase` to handle `None`s

* Fix `InferenceEndpointsLLM` unit tests after release of `huggingface_hub==0.24.0`
gabrielmbmb authored Jul 19, 2024
1 parent b22a494 commit 0ef3f70
Showing 5 changed files with 251 additions and 36 deletions.
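For context, a minimal usage sketch of the two new flags (not part of this commit; the model, tokenizer, and generation kwargs below are placeholder assumptions):

from distilabel.llms import InferenceEndpointsLLM
from distilabel.steps.tasks import Magpie

# Hypothetical usage: 2-turn conversations that end with a user message and
# keep the system prompt in the output conversation.
magpie = Magpie(
    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
        tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
        magpie_pre_query_template="llama3",
        generation_kwargs={"temperature": 1.0, "max_new_tokens": 64},
    ),
    n_turns=2,
    end_with_user=True,
    include_system_prompt=True,
)
magpie.load()
result = next(magpie.process([{"system_prompt": "You are a helpful assistant."}]))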
119 changes: 90 additions & 29 deletions src/distilabel/steps/tasks/magpie/base.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from pydantic import Field, PositiveInt

@@ -26,7 +26,7 @@
from distilabel.steps.tasks.base import Task

if TYPE_CHECKING:
-    from distilabel.steps.tasks.typing import ChatType, FormattedInput
+    from distilabel.steps.tasks.typing import ChatType
from distilabel.steps.typing import StepOutput

MAGPIE_MULTI_TURN_SYSTEM_PROMPT = (
@@ -50,6 +50,14 @@ class MagpieBase(RuntimeParametersMixin):
default=1,
description="The number of turns to generate for the conversation.",
)
+    end_with_user: RuntimeParameter[bool] = Field(
+        default=False,
+        description="Whether the conversation should end with a user message.",
+    )
+    include_system_prompt: RuntimeParameter[bool] = Field(
+        default=False,
+        description="Whether to include the system prompt used in the generated conversation.",
+    )
only_instruction: RuntimeParameter[bool] = Field(
default=False,
description="Whether to generate only the instruction. If this argument"
@@ -63,7 +71,7 @@

def _prepare_inputs_for_instruction_generation(
self, inputs: List[Dict[str, Any]]
) -> List["FormattedInput"]:
) -> List["ChatType"]:
"""Prepares the inputs adding the system (if required) prompt provided in each row,
or if the conversations to generate have more than one turn, then adding the system
prompt for multi-turn conversation from the paper.
@@ -106,7 +114,8 @@ def _append_messages_to_conversations(
The updated conversations.
"""
for instruction, conversation in zip(messages, conversations):
-            conversation.append({"role": role, "content": instruction})
+            if instruction is not None:
+                conversation.append({"role": role, "content": instruction})
return conversations

def _generate_instruction(
@@ -120,41 +129,83 @@ def _generate_instruction(
)
return [{"instruction": output[0]} for output in outputs]

+    def _prepare_conversation_outputs(
+        self, conversations: List["ChatType"]
+    ) -> List[Dict[str, Any]]:
+        """Prepare the output conversation removing the system prompt if necessary.
+
+        Args:
+            conversations: the list of generated conversations.
+
+        Returns:
+            A list of dictionaries containing a "conversation" key.
+        """
+        outputs = []
+        for conversation in conversations:
+            if not self.include_system_prompt and conversation[0]["role"] == "system":
+                conversation.pop(0)
+            outputs.append({"conversation": conversation})
+        return outputs
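Illustration (not part of the diff): with the default include_system_prompt=False, a leading system message is stripped from each returned conversation.

# Hypothetical input/output for _prepare_conversation_outputs:
conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello! How can I help?"},
]
# include_system_prompt=False -> {"conversation": [user, assistant]}
# include_system_prompt=True  -> {"conversation": [system, user, assistant]}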

+    def _generate_conversation_turn(
+        self, role: str, conversations: List["ChatType"], active_indices: List[int]
+    ) -> Tuple[List["ChatType"], List[int]]:
+        # Generate an output for the conversations that are still active (no previous `None`s)
+        outputs = self.llm.generate(
+            inputs=[conversations[idx] for idx in active_indices],
+            num_generations=1,
+            **self.llm.generation_kwargs,  # type: ignore
+        )
+
+        active_conversations = [conversations[idx] for idx in active_indices]
+        updated_conversations = self._append_messages_to_conversations(
+            role=role,
+            messages=[output[0] for output in outputs],
+            conversations=active_conversations,
+        )
+
+        for idx, conv in zip(active_indices, updated_conversations):
+            conversations[idx] = conv
+
+        new_active_indices = [
+            idx for idx, output in zip(active_indices, outputs) if output[0] is not None
+        ]
+
+        return conversations, new_active_indices
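The bookkeeping above can be illustrated in isolation (not part of the diff): a conversation stays active only while every generation for it has succeeded, so once the LLM returns `None` its index is dropped and no further turns are requested for it.

# Illustrative run of the filtering step:
outputs = [["How are you?"], [None]]  # llm.generate result: one success, one failure
active_indices = [0, 1]
new_active_indices = [
    idx for idx, output in zip(active_indices, outputs) if output[0] is not None
]
assert new_active_indices == [0]  # the second conversation gets no further turns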

    def _generate_multi_turn_conversation(
        self, inputs: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
-        conversations = self._prepare_inputs_for_instruction_generation(inputs)
-
-        for _ in range(self.n_turns):  # type: ignore
-            # Generate instruction or user message
-            outputs = self.llm.generate(
-                inputs=conversations,
-                num_generations=1,
-                **self.llm.generation_kwargs,  # type: ignore
-            )
-
-            conversations = self._append_messages_to_conversations(
-                role="user",
-                messages=[output[0] for output in outputs],
-                conversations=conversations,  # type: ignore
-            )
-
-            # TODO: handle potential previous `None`s
-
-            # Generate response
-            outputs = self.llm.generate(
-                inputs=conversations,
-                num_generations=1,
-                **self.llm.generation_kwargs,  # type: ignore
-            )
-
-            conversations = self._append_messages_to_conversations(
-                role="assistant",
-                messages=[output[0] for output in outputs],
-                conversations=conversations,  # type: ignore
-            )
-
-        return [{"conversation": conversation} for conversation in conversations]
+        conversations: List["ChatType"] = (
+            self._prepare_inputs_for_instruction_generation(inputs)
+        )
+        # Keep track of the active conversations, as it could happen that for some conversation
+        # we can't generate the next turn because the `LLM` returned `None`.
+        active_indices = list(range(len(conversations)))
+
+        for i in range(self.n_turns):  # type: ignore
+            if not active_indices:
+                break
+
+            # Generate user message
+            conversations, active_indices = self._generate_conversation_turn(
+                role="user", conversations=conversations, active_indices=active_indices
+            )
+
+            if i == self.n_turns - 1 and self.end_with_user:  # type: ignore
+                break
+
+            if not active_indices:
+                break
+
+            # Generate assistant message
+            conversations, active_indices = self._generate_conversation_turn(
+                role="assistant",
+                conversations=conversations,
+                active_indices=active_indices,
+            )
+
+        return self._prepare_conversation_outputs(conversations)
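Worked example of the turn loop above (illustrative, not from the diff): assuming every generation succeeds, the roles appended per configuration are:

# n_turns=1                      -> ["user", "assistant"]
# n_turns=2                      -> ["user", "assistant", "user", "assistant"]
# n_turns=2, end_with_user=True  -> ["user", "assistant", "user"]
# (with include_system_prompt=True, a leading "system" message is kept as well)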

def _generate_with_pre_query_template(
self, inputs: List[Dict[str, Any]]
@@ -196,6 +247,11 @@ class Magpie(Task, MagpieBase):
Attributes:
n_turns: the number of turns that the generated conversation will have.
Defaults to `1`.
+        end_with_user: whether the conversation should end with a user message.
+            Defaults to `False`.
+        include_system_prompt: whether to include the system prompt used in the generated
+            conversation. Defaults to `False`.
only_instruction: whether to generate only the instruction. If this argument is
`True`, then `n_turns` will be ignored. Defaults to `False`.
system_prompt: an optional system prompt that can be used to steer the LLM to generate
@@ -204,7 +260,12 @@ class Magpie(Task, MagpieBase):
one from the column will be used. Defaults to `None`.
Runtime parameters:
-        - `n_turns`: the number of turns that the generated conversation will have.
+        - `n_turns`: the number of turns that the generated conversation will have. Defaults
+            to `1`.
+        - `end_with_user`: whether the conversation should end with a user message.
+            Defaults to `False`.
+        - `include_system_prompt`: whether to include the system prompt used in the generated
+            conversation. Defaults to `False`.
- `only_instruction`: whether to generate only the instruction. If this argument is
`True`, then `n_turns` will be ignored. Defaults to `False`.
- `system_prompt`: an optional system prompt that can be used to steer the LLM to
20 changes: 16 additions & 4 deletions src/distilabel/steps/tasks/magpie/generator.py
@@ -42,6 +42,11 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
Attributes:
n_turns: the number of turns that the generated conversation will have.
Defaults to `1`.
+        end_with_user: whether the conversation should end with a user message.
+            Defaults to `False`.
+        include_system_prompt: whether to include the system prompt used in the generated
+            conversation. Defaults to `False`.
only_instruction: whether to generate only the instruction. If this argument is
`True`, then `n_turns` will be ignored. Defaults to `False`.
system_prompt: an optional system prompt that can be used to steer the LLM to generate
@@ -51,11 +56,18 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
num_rows: the number of rows to be generated.
Runtime parameters:
-        - `n_turns`: the number of turns that the generated conversation will have.
-        - `only_instruction`: whether to generate only the instruction. If this argument
-            is `True`, then `n_turns` will be ignored. Defaults to `False`.
+        - `n_turns`: the number of turns that the generated conversation will have. Defaults
+            to `1`.
+        - `end_with_user`: whether the conversation should end with a user message.
+            Defaults to `False`.
+        - `include_system_prompt`: whether to include the system prompt used in the generated
+            conversation. Defaults to `False`.
+        - `only_instruction`: whether to generate only the instruction. If this argument is
+            `True`, then `n_turns` will be ignored. Defaults to `False`.
         - `system_prompt`: an optional system prompt that can be used to steer the LLM to
-            generate content of certain topic, guide the style, etc. Defaults to `None`.
+            generate content of certain topic, guide the style, etc. If the provided inputs
+            contains a `system_prompt` column, then this runtime parameter will be ignored
+            and the one from the column will be used. Defaults to `None`.
- `num_rows`: the number of rows to be generated.
Output columns:
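A similar sketch for the generator variant (not part of this commit; the LLM arguments are placeholders as above):

from distilabel.llms import InferenceEndpointsLLM
from distilabel.steps.tasks import MagpieGenerator

# Hypothetical usage: 5 rows of 3-turn conversations, keeping the system prompt.
generator = MagpieGenerator(
    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
        tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
        magpie_pre_query_template="llama3",
        generation_kwargs={"temperature": 1.0, "max_new_tokens": 64},
    ),
    n_turns=3,
    include_system_prompt=True,
    num_rows=5,
)
generator.load()
result = next(generator.process())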
2 changes: 0 additions & 2 deletions tests/unit/llms/huggingface/test_inference_endpoints.py
@@ -171,7 +171,6 @@ async def test_agenerate_with_chat_completion(
created=1721045246,
id="",
model="meta-llama/Meta-Llama-3-70B-Instruct",
object="chat.completion",
system_fingerprint="2.1.1-dev0-sha-4327210",
usage=ChatCompletionOutputUsage(
completion_tokens=66, prompt_tokens=18, total_tokens=84
@@ -212,7 +211,6 @@ async def test_agenerate_with_chat_completion_fails(
created=1721045246,
id="",
model="meta-llama/Meta-Llama-3-70B-Instruct",
object="chat.completion",
system_fingerprint="2.1.1-dev0-sha-4327210",
usage=ChatCompletionOutputUsage(
completion_tokens=66, prompt_tokens=18, total_tokens=84