diff --git a/src/distilabel/steps/tasks/magpie/base.py b/src/distilabel/steps/tasks/magpie/base.py
index d154e668a5..672fcbf5ab 100644
--- a/src/distilabel/steps/tasks/magpie/base.py
+++ b/src/distilabel/steps/tasks/magpie/base.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 from pydantic import Field, PositiveInt
 
@@ -26,7 +26,7 @@ from distilabel.steps.tasks.base import Task
 
 if TYPE_CHECKING:
-    from distilabel.steps.tasks.typing import ChatType, FormattedInput
+    from distilabel.steps.tasks.typing import ChatType
     from distilabel.steps.typing import StepOutput
 
 MAGPIE_MULTI_TURN_SYSTEM_PROMPT = (
@@ -50,6 +50,14 @@ class MagpieBase(RuntimeParametersMixin):
         default=1,
         description="The number of turns to generate for the conversation.",
     )
+    end_with_user: RuntimeParameter[bool] = Field(
+        default=False,
+        description="Whether the conversation should end with a user message.",
+    )
+    include_system_prompt: RuntimeParameter[bool] = Field(
+        default=False,
+        description="Whether to include the system prompt used in the generated conversation.",
+    )
     only_instruction: RuntimeParameter[bool] = Field(
         default=False,
         description="Whether to generate only the instruction. If this argument"
@@ -63,7 +71,7 @@ class MagpieBase(RuntimeParametersMixin):
 
     def _prepare_inputs_for_instruction_generation(
         self, inputs: List[Dict[str, Any]]
-    ) -> List["FormattedInput"]:
+    ) -> List["ChatType"]:
         """Prepares the inputs adding the system (if required) prompt provided in each row,
         or if the conversations to generate have more than one turn, then adding the system
         prompt for multi-turn conversation from the paper.
@@ -106,7 +114,8 @@ def _append_messages_to_conversations(
             The updated conversations.
         """
         for instruction, conversation in zip(messages, conversations):
-            conversation.append({"role": role, "content": instruction})
+            if instruction is not None:
+                conversation.append({"role": role, "content": instruction})
         return conversations
 
     def _generate_instruction(
@@ -120,41 +129,83 @@ def _generate_instruction(
         )
         return [{"instruction": output[0]} for output in outputs]
 
+    def _prepare_conversation_outputs(
+        self, conversations: List["ChatType"]
+    ) -> List[Dict[str, Any]]:
+        """Prepares the output conversations, removing the system prompt if necessary.
+
+        Args:
+            conversations: the list of generated conversations.
+
+        Returns:
+            A list of dictionaries containing a "conversation" key.
+        """
+        outputs = []
+        for conversation in conversations:
+            if not self.include_system_prompt and conversation[0]["role"] == "system":
+                conversation.pop(0)
+            outputs.append({"conversation": conversation})
+        return outputs
+
+    def _generate_conversation_turn(
+        self, role: str, conversations: List["ChatType"], active_indices: List[int]
+    ) -> Tuple[List["ChatType"], List[int]]:
+        # Generate an output for the conversations that are still active (no previous `None`s)
+        outputs = self.llm.generate(
+            inputs=[conversations[idx] for idx in active_indices],
+            num_generations=1,
+            **self.llm.generation_kwargs,  # type: ignore
+        )
+
+        active_conversations = [conversations[idx] for idx in active_indices]
+        updated_conversations = self._append_messages_to_conversations(
+            role=role,
+            messages=[output[0] for output in outputs],
+            conversations=active_conversations,
+        )
+
+        for idx, conv in zip(active_indices, updated_conversations):
+            conversations[idx] = conv
+
+        new_active_indices = [
+            idx for idx, output in zip(active_indices, outputs) if output[0] is not None
+        ]
+
+        return conversations, new_active_indices
+
     def _generate_multi_turn_conversation(
         self, inputs: List[Dict[str, Any]]
     ) -> List[Dict[str, Any]]:
-        conversations = self._prepare_inputs_for_instruction_generation(inputs)
-
-        for _ in range(self.n_turns):  # type: ignore
-            # Generate instruction or user message
-            outputs = self.llm.generate(
-                inputs=conversations,
-                num_generations=1,
-                **self.llm.generation_kwargs,  # type: ignore
-            )
+        conversations: List["ChatType"] = (
+            self._prepare_inputs_for_instruction_generation(inputs)
+        )
+        # Keep track of the active conversations, as it could happen that for some
+        # conversations we can't generate the next turn because the `LLM` returned `None`.
+        active_indices = list(range(len(conversations)))
+
+        for i in range(self.n_turns):  # type: ignore
+            if not active_indices:
+                break
 
-            conversations = self._append_messages_to_conversations(
-                role="user",
-                messages=[output[0] for output in outputs],
-                conversations=conversations,  # type: ignore
+            # Generate user message
+            conversations, active_indices = self._generate_conversation_turn(
+                role="user", conversations=conversations, active_indices=active_indices
             )
 
-            # TODO: handle potential previous `None`s
+            if i == self.n_turns - 1 and self.end_with_user:  # type: ignore
+                break
 
-            # Generate response
-            outputs = self.llm.generate(
-                inputs=conversations,
-                num_generations=1,
-                **self.llm.generation_kwargs,  # type: ignore
-            )
+            if not active_indices:
+                break
 
-            conversations = self._append_messages_to_conversations(
+            # Generate assistant message
+            conversations, active_indices = self._generate_conversation_turn(
                 role="assistant",
-                messages=[output[0] for output in outputs],
-                conversations=conversations,  # type: ignore
+                conversations=conversations,
+                active_indices=active_indices,
             )
 
-        return [{"conversation": conversation} for conversation in conversations]
+        return self._prepare_conversation_outputs(conversations)
 
     def _generate_with_pre_query_template(
         self, inputs: List[Dict[str, Any]]
@@ -196,6 +247,11 @@ class Magpie(Task, MagpieBase):
 
     Attributes:
         n_turns: the number of turns that the generated conversation will have.
+            Defaults to `1`.
+        end_with_user: whether the conversation should end with a user message.
+            Defaults to `False`.
+        include_system_prompt: whether to include the system prompt used in the generated
+            conversation. Defaults to `False`.
         only_instruction: whether to generate only the instruction.
            If this argument is `True`, then `n_turns` will be ignored. Defaults to `False`.
         system_prompt: an optional system prompt that can be used to steer the LLM to generate
@@ -204,7 +260,12 @@ class Magpie(Task, MagpieBase):
         one from the column will be used. Defaults to `None`.
 
     Runtime parameters:
-        - `n_turns`: the number of turns that the generated conversation will have.
+        - `n_turns`: the number of turns that the generated conversation will have. Defaults
+          to `1`.
+        - `end_with_user`: whether the conversation should end with a user message.
+          Defaults to `False`.
+        - `include_system_prompt`: whether to include the system prompt used in the generated
+          conversation. Defaults to `False`.
         - `only_instruction`: whether to generate only the instruction. If this argument
           is `True`, then `n_turns` will be ignored. Defaults to `False`.
         - `system_prompt`: an optional system prompt that can be used to steer the LLM to
diff --git a/src/distilabel/steps/tasks/magpie/generator.py b/src/distilabel/steps/tasks/magpie/generator.py
index 34b18a015e..97ae6be45e 100644
--- a/src/distilabel/steps/tasks/magpie/generator.py
+++ b/src/distilabel/steps/tasks/magpie/generator.py
@@ -42,6 +42,11 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
 
     Attributes:
         n_turns: the number of turns that the generated conversation will have.
+            Defaults to `1`.
+        end_with_user: whether the conversation should end with a user message.
+            Defaults to `False`.
+        include_system_prompt: whether to include the system prompt used in the generated
+            conversation. Defaults to `False`.
         only_instruction: whether to generate only the instruction. If this argument
           is `True`, then `n_turns` will be ignored. Defaults to `False`.
         system_prompt: an optional system prompt that can be used to steer the LLM to generate
@@ -51,11 +56,18 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
         num_rows: the number of rows to be generated.
 
     Runtime parameters:
-        - `n_turns`: the number of turns that the generated conversation will have.
-        - `only_instruction`: whether to generate only the instruction. If this argument
-          is `True`, then `n_turns` will be ignored. Defaults to `False`.
+        - `n_turns`: the number of turns that the generated conversation will have. Defaults
+          to `1`.
+        - `end_with_user`: whether the conversation should end with a user message.
+          Defaults to `False`.
+        - `include_system_prompt`: whether to include the system prompt used in the generated
+          conversation. Defaults to `False`.
+        - `only_instruction`: whether to generate only the instruction. If this argument is
+          `True`, then `n_turns` will be ignored. Defaults to `False`.
         - `system_prompt`: an optional system prompt that can be used to steer the LLM to
-          generate content of certain topic, guide the style, etc. Defaults to `None`.
+          generate content on a certain topic, guide the style, etc. If the provided inputs
+          contain a `system_prompt` column, then this runtime parameter will be ignored
+          and the one from the column will be used. Defaults to `None`.
         - `num_rows`: the number of rows to be generated.
 
     Output columns:
diff --git a/tests/unit/llms/huggingface/test_inference_endpoints.py b/tests/unit/llms/huggingface/test_inference_endpoints.py
index 436815b0e5..cd45c91153 100644
--- a/tests/unit/llms/huggingface/test_inference_endpoints.py
+++ b/tests/unit/llms/huggingface/test_inference_endpoints.py
@@ -171,7 +171,6 @@ async def test_agenerate_with_chat_completion(
             created=1721045246,
             id="",
             model="meta-llama/Meta-Llama-3-70B-Instruct",
-            object="chat.completion",
             system_fingerprint="2.1.1-dev0-sha-4327210",
             usage=ChatCompletionOutputUsage(
                 completion_tokens=66, prompt_tokens=18, total_tokens=84
@@ -212,7 +211,6 @@ async def test_agenerate_with_chat_completion_fails(
             created=1721045246,
             id="",
             model="meta-llama/Meta-Llama-3-70B-Instruct",
-            object="chat.completion",
             system_fingerprint="2.1.1-dev0-sha-4327210",
             usage=ChatCompletionOutputUsage(
                 completion_tokens=66, prompt_tokens=18, total_tokens=84
diff --git a/tests/unit/steps/tasks/magpie/test_base.py b/tests/unit/steps/tasks/magpie/test_base.py
index 885b129b88..55973ef8a1 100644
--- a/tests/unit/steps/tasks/magpie/test_base.py
+++ b/tests/unit/steps/tasks/magpie/test_base.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from unittest import mock
+
 import pytest
 from distilabel.llms.openai import OpenAILLM
 from distilabel.steps.tasks.magpie.base import MAGPIE_MULTI_TURN_SYSTEM_PROMPT, Magpie
@@ -68,11 +70,125 @@ def test_process(self) -> None:
             },
         ]
 
+    def test_process_failing_generation_for_some_rows(self) -> None:
+        with mock.patch(
+            "tests.unit.conftest.DummyMagpieLLM.generate",
+            side_effect=[
+                [["Hello Magpie"], [None], ["Hello Magpie"]],
+                [["Hello Magpie"], ["Hello Magpie"]],
+                [["Hello Magpie"], [None]],
+                [["Hello Magpie"]],
+            ],
+        ):
+            task = Magpie(
+                llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), n_turns=2
+            )
+
+            task.load()
+
+            assert next(task.process(inputs=[{}, {}, {}])) == [
+                {
+                    "conversation": [
+                        {"role": "user", "content": "Hello Magpie"},
+                        {"role": "assistant", "content": "Hello Magpie"},
+                        {"role": "user", "content": "Hello Magpie"},
+                        {"role": "assistant", "content": "Hello Magpie"},
+                    ],
+                    "model_name": "test",
+                },
+                {
+                    "conversation": [],
+                    "model_name": "test",
+                },
+                {
+                    "conversation": [
+                        {"role": "user", "content": "Hello Magpie"},
+                        {"role": "assistant", "content": "Hello Magpie"},
+                    ],
+                    "model_name": "test",
+                },
+            ]
+
     def test_process_with_n_turns(self) -> None:
         task = Magpie(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), n_turns=2)
 
         task.load()
 
+        assert next(task.process(inputs=[{}, {}, {}])) == [
+            {
+                "conversation": [
+                    {"role": "user", "content": "Hello Magpie"},
+                    {"role": "assistant", "content": "Hello Magpie"},
+                    {"role": "user", "content": "Hello Magpie"},
+                    {"role": "assistant", "content": "Hello Magpie"},
+                ],
+                "model_name": "test",
+            },
+            {
+                "conversation": [
+                    {"role": "user", "content": "Hello Magpie"},
+                    {"role": "assistant", "content": "Hello Magpie"},
+                    {"role": "user", "content": "Hello Magpie"},
+                    {"role": "assistant", "content": "Hello Magpie"},
+                ],
+                "model_name": "test",
+            },
+            {
+                "conversation": [
+                    {"role": "user", "content": "Hello Magpie"},
+                    {"role": "assistant", "content": "Hello Magpie"},
+                    {"role": "user", "content": "Hello Magpie"},
+                    {"role": "assistant", "content": "Hello Magpie"},
+                ],
+                "model_name": "test",
+            },
+        ]
+
+    def test_process_with_end_with_user(self) -> None:
+        task = Magpie(
+            llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
+            n_turns=2,
+            end_with_user=True,
+        )
+
+        task.load()
+
+        assert next(task.process(inputs=[{}, {}, {}])) == [
+            {
+                "conversation": [
+                    {"role": "user", "content": "Hello Magpie"},
+                    {"role": "assistant", "content": "Hello Magpie"},
+                    {"role": "user", "content": "Hello Magpie"},
+                ],
+                "model_name": "test",
+            },
+            {
+                "conversation": [
+                    {"role": "user", "content": "Hello Magpie"},
+                    {"role": "assistant", "content": "Hello Magpie"},
+                    {"role": "user", "content": "Hello Magpie"},
+                ],
+                "model_name": "test",
+            },
+            {
+                "conversation": [
+                    {"role": "user", "content": "Hello Magpie"},
+                    {"role": "assistant", "content": "Hello Magpie"},
+                    {"role": "user", "content": "Hello Magpie"},
+                ],
+                "model_name": "test",
+            },
+        ]
+
+    def test_process_with_include_system_prompt(self) -> None:
+        task = Magpie(
+            llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
+            n_turns=2,
+            include_system_prompt=True,
+        )
+
+        task.load()
+
         assert next(task.process(inputs=[{}, {}, {}])) == [
             {
                 "conversation": [
@@ -107,7 +223,11 @@ def test_process_with_n_turns(self) -> None:
         ]
 
     def test_process_with_system_prompt_per_row(self) -> None:
-        task = Magpie(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), n_turns=2)
+        task = Magpie(
+            llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
+            n_turns=2,
+            include_system_prompt=True,
+        )
 
         task.load()
 
@@ -195,6 +315,8 @@ def test_serialization(self) -> None:
                 },
             },
             "n_turns": 1,
+            "end_with_user": False,
+            "include_system_prompt": False,
             "only_instruction": True,
             "system_prompt": None,
             "name": "magpie_0",
@@ -227,6 +349,16 @@ def test_serialization(self) -> None:
                     "optional": True,
                    "description": "The number of turns to generate for the conversation.",
                 },
+                {
+                    "name": "end_with_user",
+                    "optional": True,
+                    "description": "Whether the conversation should end with a user message.",
+                },
+                {
+                    "name": "include_system_prompt",
+                    "optional": True,
+                    "description": "Whether to include the system prompt used in the generated conversation.",
+                },
                 {
                     "name": "only_instruction",
                     "optional": True,
diff --git a/tests/unit/steps/tasks/magpie/test_generator.py b/tests/unit/steps/tasks/magpie/test_generator.py
index d10638acd7..c770872db0 100644
--- a/tests/unit/steps/tasks/magpie/test_generator.py
+++ b/tests/unit/steps/tasks/magpie/test_generator.py
@@ -53,6 +53,8 @@ def test_serialization(self) -> None:
                 },
             },
             "n_turns": 1,
+            "end_with_user": False,
+            "include_system_prompt": False,
             "only_instruction": False,
             "system_prompt": None,
             "name": "magpie_generator_0",
@@ -86,6 +88,16 @@ def test_serialization(self) -> None:
                     "optional": True,
                    "description": "The number of turns to generate for the conversation.",
                 },
+                {
+                    "name": "end_with_user",
+                    "optional": True,
+                    "description": "Whether the conversation should end with a user message.",
+                },
+                {
+                    "name": "include_system_prompt",
+                    "optional": True,
+                    "description": "Whether to include the system prompt used in the generated conversation.",
+                },
                 {
                     "name": "only_instruction",
                     "optional": True,
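
Usage sketch (not part of the patch): a minimal example of how the new `end_with_user` and `include_system_prompt` runtime parameters combine with `n_turns`. The `InferenceEndpointsLLM` arguments, model ids, generation kwargs, and system prompt below are illustrative assumptions; any `LLM` that exposes a `magpie_pre_query_template` should behave the same way.

from distilabel.llms import InferenceEndpointsLLM
from distilabel.steps.tasks import Magpie

magpie = Magpie(
    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3-70B-Instruct",  # assumed model
        tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
        magpie_pre_query_template="llama3",
        generation_kwargs={"temperature": 1.0, "max_new_tokens": 256},
    ),
    n_turns=2,
    end_with_user=True,  # stop right after the last generated user message
    include_system_prompt=True,  # keep the system prompt in the output conversation
    system_prompt="You are an AI assistant specialized in math.",  # assumed prompt
)
magpie.load()

# Each row yields a "conversation" column. With n_turns=2 and end_with_user=True
# a complete conversation looks like [system, user, assistant, user]; rows for
# which the LLM returned None at some turn keep only the turns generated before
# that point, as exercised in test_process_failing_generation_for_some_rows above.
result = next(magpie.process(inputs=[{}]))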