Skip to content

Commit

Permalink
Update system_prompt so it can also be a list
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmbmb committed Jul 26, 2024
1 parent 61d8de1 commit 5657ffd
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 24 deletions.
33 changes: 21 additions & 12 deletions src/distilabel/steps/tasks/magpie/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from pydantic import Field, PositiveInt
Expand Down Expand Up @@ -63,10 +64,10 @@ class MagpieBase(RuntimeParametersMixin):
description="Whether to generate only the instruction. If this argument"
" is `True`, then `n_turns` will be ignored.",
)
system_prompt: Optional[RuntimeParameter[str]] = Field(
system_prompt: Optional[RuntimeParameter[Union[List[str], str]]] = Field(
default=None,
description="An optional system prompt that can be used to steer the LLM to generate"
" content of certain topic, guide the style, etc.",
description="An optional system prompt or list of system prompts that can be used"
" to steer the LLM to generate content of certain topic, guide the style, etc.",
)

def _prepare_inputs_for_instruction_generation(
Expand All @@ -90,7 +91,11 @@ def _prepare_inputs_for_instruction_generation(
{"role": "system", "content": input["system_prompt"]}
)
elif self.system_prompt is not None:
conversation.append({"role": "system", "content": self.system_prompt})
if isinstance(self.system_prompt, list):
system_prompt = random.choice(self.system_prompt)
else:
system_prompt = self.system_prompt
conversation.append({"role": "system", "content": system_prompt})
elif self.n_turns > 1: # type: ignore
conversation.append(
{"role": "system", "content": MAGPIE_MULTI_TURN_SYSTEM_PROMPT}
Expand Down Expand Up @@ -265,10 +270,12 @@ class Magpie(Task, MagpieBase):
conversation. Defaults to `False`.
only_instruction: whether to generate only the instruction. If this argument is
`True`, then `n_turns` will be ignored. Defaults to `False`.
system_prompt: an optional system prompt that can be used to steer the LLM to generate
content of certain topic, guide the style, etc. If the provided inputs contains
a `system_prompt` column, then this runtime parameter will be ignored and the
one from the column will be used. Defaults to `None`.
system_prompt: an optional system prompt or list of system prompts that can
be used to steer the LLM to generate content of certain topic, guide the style,
etc. If it's a list of system prompts, then a random system prompt will be chosen
per input/output batch. If the provided inputs contain a `system_prompt` column,
then this runtime parameter will be ignored and the one from the column will
be used. Defaults to `None`.
Runtime parameters:
- `n_turns`: the number of turns that the generated conversation will have. Defaults
Expand All @@ -279,10 +286,12 @@ class Magpie(Task, MagpieBase):
conversation. Defaults to `False`.
- `only_instruction`: whether to generate only the instruction. If this argument is
`True`, then `n_turns` will be ignored. Defaults to `False`.
- `system_prompt`: an optional system prompt that can be used to steer the LLM to
generate content of certain topic, guide the style, etc. If the provided inputs
contains a `system_prompt` column, then this runtime parameter will be ignored
and the one from the column will be used. Defaults to `None`.
- `system_prompt`: an optional system prompt or list of system prompts that can
be used to steer the LLM to generate content of certain topic, guide the style,
etc. If it's a list of system prompts, then a random system prompt will be chosen
per input/output batch. If the provided inputs contain a `system_prompt` column,
then this runtime parameter will be ignored and the one from the column will
be used. Defaults to `None`.
Input columns:
- system_prompt (`str`, optional): an optional system prompt that can be provided
Expand Down
20 changes: 12 additions & 8 deletions src/distilabel/steps/tasks/magpie/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
conversation. Defaults to `False`.
only_instruction: whether to generate only the instruction. If this argument is
`True`, then `n_turns` will be ignored. Defaults to `False`.
system_prompt: an optional system prompt that can be used to steer the LLM to generate
content of certain topic, guide the style, etc. If the provided inputs contains
a `system_prompt` column, then this runtime parameter will be ignored and the
one from the column will be used. Defaults to `None`.
system_prompt: an optional system prompt or list of system prompts that can
be used to steer the LLM to generate content of certain topic, guide the style,
etc. If it's a list of system prompts, then a random system prompt will be chosen
per input/output batch. If the provided inputs contain a `system_prompt` column,
then this runtime parameter will be ignored and the one from the column will
be used. Defaults to `None`.
num_rows: the number of rows to be generated.
Runtime parameters:
Expand All @@ -64,10 +66,12 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
conversation. Defaults to `False`.
- `only_instruction`: whether to generate only the instruction. If this argument is
`True`, then `n_turns` will be ignored. Defaults to `False`.
- `system_prompt`: an optional system prompt that can be used to steer the LLM to
generate content of certain topic, guide the style, etc. If the provided inputs
contains a `system_prompt` column, then this runtime parameter will be ignored
and the one from the column will be used. Defaults to `None`.
- `system_prompt`: an optional system prompt or list of system prompts that can
be used to steer the LLM to generate content of certain topic, guide the style,
etc. If it's a list of system prompts, then a random system prompt will be chosen
per input/output batch. If the provided inputs contain a `system_prompt` column,
then this runtime parameter will be ignored and the one from the column will
be used. Defaults to `None`.
- `num_rows`: the number of rows to be generated.
Output columns:
Expand Down
94 changes: 92 additions & 2 deletions tests/unit/steps/tasks/magpie/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from unittest import mock

import pytest
Expand All @@ -30,7 +31,11 @@ def test_raise_value_error_llm_no_magpie_mixin(self) -> None:
Magpie(llm=OpenAILLM(model="gpt-4", api_key="fake")) # type: ignore

def test_outputs(self) -> None:
task = Magpie(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"))
task = Magpie(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), n_turns=1)

assert task.outputs == ["instruction", "response", "model_name"]

task = Magpie(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), n_turns=2)

assert task.outputs == ["conversation", "model_name"]

Expand All @@ -46,23 +51,108 @@ def test_process(self) -> None:

task.load()

assert next(task.process(inputs=[{}, {}, {}])) == [
{
"instruction": "Hello Magpie",
"response": "Hello Magpie",
"model_name": "test",
},
{
"instruction": "Hello Magpie",
"response": "Hello Magpie",
"model_name": "test",
},
{
"instruction": "Hello Magpie",
"response": "Hello Magpie",
"model_name": "test",
},
]

def test_process_with_system_prompt(self) -> None:
    """A plain-string ``system_prompt`` must be inserted as the leading
    ``system`` message of every generated conversation.

    With a single string (not a list) there is no random choice involved,
    so all three inputs must share the exact same system message.
    """
    task = Magpie(
        llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
        n_turns=2,
        system_prompt="This is a system prompt.",
        include_system_prompt=True,
    )

    task.load()

    # n_turns=2 => each conversation is: system + 2 * (user, assistant).
    # The dummy LLM always replies "Hello Magpie", which makes the full
    # conversation deterministic and exactly comparable.
    assert next(task.process(inputs=[{}, {}, {}])) == [
        {
            "conversation": [
                {"role": "system", "content": "This is a system prompt."},
                {"role": "user", "content": "Hello Magpie"},
                {"role": "assistant", "content": "Hello Magpie"},
                {"role": "user", "content": "Hello Magpie"},
                {"role": "assistant", "content": "Hello Magpie"},
            ],
            "model_name": "test",
        },
        {
            "conversation": [
                {"role": "system", "content": "This is a system prompt."},
                {"role": "user", "content": "Hello Magpie"},
                {"role": "assistant", "content": "Hello Magpie"},
                {"role": "user", "content": "Hello Magpie"},
                {"role": "assistant", "content": "Hello Magpie"},
            ],
            "model_name": "test",
        },
        {
            "conversation": [
                {"role": "system", "content": "This is a system prompt."},
                {"role": "user", "content": "Hello Magpie"},
                {"role": "assistant", "content": "Hello Magpie"},
                {"role": "user", "content": "Hello Magpie"},
                {"role": "assistant", "content": "Hello Magpie"},
            ],
            "model_name": "test",
        },
    ]

def test_process_with_several_system_prompts(self) -> None:
task = Magpie(
llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
n_turns=2,
system_prompt=[
"This is a system prompt.",
"This is another system prompt.",
],
include_system_prompt=True,
)

random.seed(42)

task.load()

assert next(task.process(inputs=[{}, {}, {}])) == [
{
"conversation": [
{"role": "system", "content": "This is a system prompt."},
{"role": "user", "content": "Hello Magpie"},
{"role": "assistant", "content": "Hello Magpie"},
{"role": "user", "content": "Hello Magpie"},
{"role": "assistant", "content": "Hello Magpie"},
],
"model_name": "test",
},
{
"conversation": [
{"role": "system", "content": "This is a system prompt."},
{"role": "user", "content": "Hello Magpie"},
{"role": "assistant", "content": "Hello Magpie"},
{"role": "user", "content": "Hello Magpie"},
{"role": "assistant", "content": "Hello Magpie"},
],
"model_name": "test",
},
{
"conversation": [
{"role": "system", "content": "This is another system prompt."},
{"role": "user", "content": "Hello Magpie"},
{"role": "assistant", "content": "Hello Magpie"},
{"role": "user", "content": "Hello Magpie"},
{"role": "assistant", "content": "Hello Magpie"},
],
Expand Down Expand Up @@ -367,7 +457,7 @@ def test_serialization(self) -> None:
{
"name": "system_prompt",
"optional": True,
"description": "An optional system prompt that can be used to steer the LLM to generate content of certain topic, guide the style, etc.",
"description": "An optional system prompt or list of system prompts that can be used to steer the LLM to generate content of certain topic, guide the style, etc.",
},
{
"name": "resources",
Expand Down
12 changes: 10 additions & 2 deletions tests/unit/steps/tasks/magpie/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,15 @@ def test_raise_value_error_llm_no_magpie_mixin(self) -> None:
MagpieGenerator(llm=OpenAILLM(model="gpt-4", api_key="fake")) # type: ignore

def test_outputs(self) -> None:
task = MagpieGenerator(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"))
task = MagpieGenerator(
llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), n_turns=1
)

assert task.outputs == ["instruction", "response", "model_name"]

task = MagpieGenerator(
llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), n_turns=2
)

assert task.outputs == ["conversation", "model_name"]

Expand Down Expand Up @@ -106,7 +114,7 @@ def test_serialization(self) -> None:
{
"name": "system_prompt",
"optional": True,
"description": "An optional system prompt that can be used to steer the LLM to generate content of certain topic, guide the style, etc.",
"description": "An optional system prompt or list of system prompts that can be used to steer the LLM to generate content of certain topic, guide the style, etc.",
},
{
"name": "resources",
Expand Down

0 comments on commit 5657ffd

Please sign in to comment.