From 9daf07474f0160ecd107453da6f2bb1747d4a811 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Wed, 21 Aug 2024 16:37:49 +0200 Subject: [PATCH 1/7] Initial work for `URIAL` --- src/distilabel/steps/tasks/__init__.py | 2 + .../steps/tasks/templates/urial.jinja2 | 15 +++++ src/distilabel/steps/tasks/text_generation.py | 11 ++-- src/distilabel/steps/tasks/ultrafeedback.py | 15 ++--- src/distilabel/steps/tasks/urial.py | 64 +++++++++++++++++++ 5 files changed, 94 insertions(+), 13 deletions(-) create mode 100644 src/distilabel/steps/tasks/templates/urial.jinja2 create mode 100644 src/distilabel/steps/tasks/urial.py diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py index 0b3a69596b..7bd96c3ce0 100644 --- a/src/distilabel/steps/tasks/__init__.py +++ b/src/distilabel/steps/tasks/__init__.py @@ -46,6 +46,7 @@ from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration from distilabel.steps.tasks.typing import ChatItem, ChatType from distilabel.steps.tasks.ultrafeedback import UltraFeedback +from distilabel.steps.tasks.urial import URIAL __all__ = [ "GeneratorTask", @@ -79,4 +80,5 @@ "ChatItem", "ChatType", "UltraFeedback", + "URIAL", ] diff --git a/src/distilabel/steps/tasks/templates/urial.jinja2 b/src/distilabel/steps/tasks/templates/urial.jinja2 new file mode 100644 index 0000000000..123d3acbb0 --- /dev/null +++ b/src/distilabel/steps/tasks/templates/urial.jinja2 @@ -0,0 +1,15 @@ +# Instruction + +Below is a list of conversations between a human and an AI assistant (you). +Users place their queries under "# User:", and your responses are under "# Assistant:". +You are a helpful, respectful, and honest assistant. +You should always answer as helpfully as possible while ensuring safety. +Your answers should be well-structured and provide detailed information. They should also have an engaging tone. 
+Your responses must not contain any fake, harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, even if it may be helpful. +Your response must be socially responsible, and thus you can refuse to answer some controversial topics. + +{% for message in messages %} +# {{ message.role | capitalize }} + +{{ message.content }} +{% endfor %} diff --git a/src/distilabel/steps/tasks/text_generation.py b/src/distilabel/steps/tasks/text_generation.py index f5c4659651..aa248c700b 100644 --- a/src/distilabel/steps/tasks/text_generation.py +++ b/src/distilabel/steps/tasks/text_generation.py @@ -13,12 +13,15 @@ # limitations under the License. import warnings -from typing import Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, Dict, List, Union from distilabel.steps.tasks.base import Task -from distilabel.steps.tasks.typing import ChatType from distilabel.utils.chat import is_openai_format +if TYPE_CHECKING: + from distilabel.steps.tasks.typing import ChatType + from distilabel.steps.typing import StepColumns + class TextGeneration(Task): """Simple text generation with an `LLM` given an instruction. @@ -78,11 +81,11 @@ class TextGeneration(Task): use_system_prompt: bool = True @property - def inputs(self) -> List[str]: + def inputs(self) -> "StepColumns": """The input for the task is the `instruction`.""" return ["instruction"] - def format_input(self, input: Dict[str, Any]) -> ChatType: + def format_input(self, input: Dict[str, Any]) -> "ChatType": """The input is formatted as a `ChatType` assuming that the instruction is the first interaction from the user within a conversation.""" diff --git a/src/distilabel/steps/tasks/ultrafeedback.py b/src/distilabel/steps/tasks/ultrafeedback.py index eec232aabd..dae68bb48f 100644 --- a/src/distilabel/steps/tasks/ultrafeedback.py +++ b/src/distilabel/steps/tasks/ultrafeedback.py @@ -12,14 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import importlib.resources as importlib_resources import re -import sys - -if sys.version_info < (3, 9): - import importlib_resources -else: - import importlib.resources as importlib_resources - from typing import Any, Dict, List, Literal, Optional, Union import orjson @@ -264,7 +258,7 @@ def outputs(self) -> List[str]: return columns + ["model_name"] def format_output( - self, output: Union[str, None], input: Dict[str, Any] + self, output: Union[str, None], input: Union[Dict[str, Any], None] = None ) -> Dict[str, Any]: """The output is formatted as a dictionary with the `ratings` and `rationales` for each of the provided `generations` for the given `instruction`. The `model_name` @@ -281,12 +275,15 @@ def format_output( `ratings`, and `rationales-for-ratings` for each of the provided `generations` for the given `instruction` if the provided aspect is either `helpfulness` or `truthfulness`. """ + assert input is not None, "Input is required to format the output." + if self.aspect in [ "honesty", "instruction-following", "overall-rating", ]: return self._format_ratings_rationales_output(output, input) + return self._format_types_ratings_rationales_output(output, input) def _format_ratings_rationales_output( @@ -450,7 +447,7 @@ class SchemaUltraFeedbackWithType(BaseModel): def _format_structured_output( self, output: str, input: Dict[str, Any] - ) -> Dict[str, str]: + ) -> Dict[str, Any]: """Parses the structured response, which should correspond to a dictionary with either `positive`, or `positive` and `negative` keys. diff --git a/src/distilabel/steps/tasks/urial.py b/src/distilabel/steps/tasks/urial.py new file mode 100644 index 0000000000..8ae8b40518 --- /dev/null +++ b/src/distilabel/steps/tasks/urial.py @@ -0,0 +1,64 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.resources as importlib_resources +from typing import TYPE_CHECKING, Any, Dict, Union + +from jinja2 import Template + +from distilabel.steps.tasks import Task + +if TYPE_CHECKING: + from distilabel.steps.tasks.typing import ChatType + from distilabel.steps.typing import StepColumns + + +class URIAL(Task): + def load(self) -> None: + """Loads the Jinja2 template for the given `aspect`.""" + super().load() + + _path = str( + importlib_resources.files("distilabel") + / "steps" + / "tasks" + / "templates" + / "ultrafeedback" + / "urial.jinja2" + ) + + self._template = Template(open(_path).read()) + + @property + def inputs(self) -> "StepColumns": + return {"instruction": False, "conversation": False} + + def format_input(self, input: Dict[str, Any]) -> "ChatType": + messages = ( + [{"role": "user", "content": input["instruction"]}] + if "instruction" in input + else input["conversation"] + ) + return [{"role": "user", "content": self._template.render(messages=messages)}] + + @property + def outputs(self) -> "StepColumns": + return ["generation", "model_name"] + + def format_output( + self, output: Union[str, None], input: Union[Dict[str, Any], None] = None + ) -> Dict[str, Any]: + if output is None: + return {} + pass From c36632e6ff58a85e5648f3dbcf8ebf8e236e96d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Wed, 21 Aug 2024 16:46:53 +0200 Subject: [PATCH 2/7] Update template --- src/distilabel/steps/tasks/templates/urial.jinja2 | 4 +++- src/distilabel/steps/tasks/urial.py | 4 ++++ 2 files changed, 7 
insertions(+), 1 deletion(-) diff --git a/src/distilabel/steps/tasks/templates/urial.jinja2 b/src/distilabel/steps/tasks/templates/urial.jinja2 index 123d3acbb0..bdc49326ec 100644 --- a/src/distilabel/steps/tasks/templates/urial.jinja2 +++ b/src/distilabel/steps/tasks/templates/urial.jinja2 @@ -9,7 +9,9 @@ Your responses must not contain any fake, harmful, unethical, racist, sexist, to Your response must be socially responsible, and thus you can refuse to answer some controversial topics. {% for message in messages %} -# {{ message.role | capitalize }} +# {{ message.role | capitalize }}: {{ message.content }} {% endfor %} + +# Assistant: diff --git a/src/distilabel/steps/tasks/urial.py b/src/distilabel/steps/tasks/urial.py index 8ae8b40518..5abda3b478 100644 --- a/src/distilabel/steps/tasks/urial.py +++ b/src/distilabel/steps/tasks/urial.py @@ -50,6 +50,10 @@ def format_input(self, input: Dict[str, Any]) -> "ChatType": if "instruction" in input else input["conversation"] ) + + if input["messages"][-1]["role"] != "user": + raise ValueError("The last message must be from the user.") + return [{"role": "user", "content": self._template.render(messages=messages)}] @property From 1417a6ca88087a585f5b8e8395bc6afd01dbad3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Wed, 21 Aug 2024 16:49:26 +0200 Subject: [PATCH 3/7] Fix checking last message --- src/distilabel/steps/tasks/text_generation.py | 4 ++-- src/distilabel/steps/tasks/urial.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/distilabel/steps/tasks/text_generation.py b/src/distilabel/steps/tasks/text_generation.py index aa248c700b..aeb74c9ec7 100644 --- a/src/distilabel/steps/tasks/text_generation.py +++ b/src/distilabel/steps/tasks/text_generation.py @@ -192,7 +192,7 @@ def inputs(self) -> List[str]: """The input for the task are the `messages`.""" return ["messages"] - def format_input(self, input: Dict[str, Any]) -> ChatType: + def 
format_input(self, input: Dict[str, Any]) -> "ChatType": """The input is formatted as a `ChatType` assuming that the messages provided are already formatted that way i.e. following the OpenAI chat format.""" @@ -216,7 +216,7 @@ def outputs(self) -> List[str]: return ["generation", "model_name"] def format_output( - self, output: Union[str, None], input: Dict[str, Any] + self, output: Union[str, None], input: Union[Dict[str, Any], None] = None ) -> Dict[str, Any]: """The output is formatted as a dictionary with the `generation`. The `model_name` will be automatically included within the `process` method of `Task`.""" diff --git a/src/distilabel/steps/tasks/urial.py b/src/distilabel/steps/tasks/urial.py index 5abda3b478..cf2c1392d8 100644 --- a/src/distilabel/steps/tasks/urial.py +++ b/src/distilabel/steps/tasks/urial.py @@ -51,7 +51,7 @@ def format_input(self, input: Dict[str, Any]) -> "ChatType": else input["conversation"] ) - if input["messages"][-1]["role"] != "user": + if messages[-1]["role"] != "user": raise ValueError("The last message must be from the user.") return [{"role": "user", "content": self._template.render(messages=messages)}] From 8f95932c1e9501bebab6cf2e5b01156d5c30069f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Wed, 21 Aug 2024 17:02:36 +0200 Subject: [PATCH 4/7] Add `format_output` logic --- src/distilabel/steps/tasks/urial.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/distilabel/steps/tasks/urial.py b/src/distilabel/steps/tasks/urial.py index cf2c1392d8..72bdf14eb9 100644 --- a/src/distilabel/steps/tasks/urial.py +++ b/src/distilabel/steps/tasks/urial.py @@ -34,7 +34,6 @@ def load(self) -> None: / "steps" / "tasks" / "templates" - / "ultrafeedback" / "urial.jinja2" ) @@ -64,5 +63,6 @@ def format_output( self, output: Union[str, None], input: Union[Dict[str, Any], None] = None ) -> Dict[str, Any]: if output is None: - return {} - pass + return {"generation": None} + + 
return {"generation": output.split("\n\n# User")[-1]} From 0b0bd93231d051c832e50e2260d32873b2df9391 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Wed, 21 Aug 2024 17:32:14 +0200 Subject: [PATCH 5/7] Refine `format_output` and add docstring --- .../steps/tasks/templates/urial.jinja2 | 1 - src/distilabel/steps/tasks/urial.py | 56 ++++++++++++++++++- 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/src/distilabel/steps/tasks/templates/urial.jinja2 b/src/distilabel/steps/tasks/templates/urial.jinja2 index bdc49326ec..09a45bcc58 100644 --- a/src/distilabel/steps/tasks/templates/urial.jinja2 +++ b/src/distilabel/steps/tasks/templates/urial.jinja2 @@ -13,5 +13,4 @@ Your response must be socially responsible, and thus you can refuse to answer so {{ message.content }} {% endfor %} - # Assistant: diff --git a/src/distilabel/steps/tasks/urial.py b/src/distilabel/steps/tasks/urial.py index 72bdf14eb9..5a1a57f973 100644 --- a/src/distilabel/steps/tasks/urial.py +++ b/src/distilabel/steps/tasks/urial.py @@ -25,6 +25,55 @@ class URIAL(Task): + """Generates a response using a non-instruct fine-tuned model. + + `URIAL` is a pre-defined task that generates a response using a non-instruct fine-tuned + model. This task is used to generate a response based on the conversation provided as + input. + + Input columns: + - instruction (`str`, optional): The instruction to generate a response from. + - conversation (`List[Dict[str, str]]`, optional): The conversation to generate + a response from (the last message must be from the user). + + Output columns: + - generation (`str`): The generated response. + - model_name (`str`): The name of the model used to generate the response. 
+ + Categories: + - text-generation + + Examples: + + Generate text from an instruction: + + ```python + from distilabel.llms import vLLM + from distilabel.steps.tasks import URIAL + + step = URIAL( + llm=vLLM( + model="meta-llama/Meta-Llama-3.1-8B", + generation_kwargs={"temperature": 0.7}, + ), + ) + + step.load() + + results = next( + step.process(inputs=[{"instruction": "What's the most common type of cloud?"}]) + ) + # [ + # { + # 'instruction': "What's the most common type of cloud?", + # 'generation': 'Clouds are classified into three main types, high, middle, and low. The most common type of cloud is the middle cloud.', + # 'distilabel_metadata': {...}, + # 'model_name': 'meta-llama/Meta-Llama-3.1-8B' + # } + # ] + ``` + """ + def load(self) -> None: """Loads the Jinja2 template for the given `aspect`.""" super().load() @@ -65,4 +114,9 @@ def format_output( if output is None: return {"generation": None} - return {"generation": output.split("\n\n# User")[-1]} + response = output.split("\n\n# User")[0] + if response.startswith("\n\n"): + response = response[2:] + response = response.strip() + + return {"generation": response} From 8d2bcadc0738296630e6280c85c000e84b7522f9 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Wed, 21 Aug 2024 17:33:54 +0200 Subject: [PATCH 6/7] Add `References` --- src/distilabel/steps/tasks/urial.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/distilabel/steps/tasks/urial.py b/src/distilabel/steps/tasks/urial.py index 5a1a57f973..ed0e72d969 100644 --- a/src/distilabel/steps/tasks/urial.py +++ b/src/distilabel/steps/tasks/urial.py @@ -43,6 +43,9 @@ class URIAL(Task): Categories: - text-generation + References: + - [The Unlocking Spell on Base LLMs: Rethinking Alignment via In-Context Learning](https://arxiv.org/abs/2312.01552) + Examples: Generate text from an instruction: From 53230acf9c7528839658e3880016f90d2061892d Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Wed, 21 Aug 2024 17:42:22 +0200 Subject:
[PATCH 7/7] Add `URIAL` unit tests --- .../unit/steps/tasks/test_text_generation.py | 5 +- tests/unit/steps/tasks/test_urial.py | 72 +++++++++++++++++++ 2 files changed, 73 insertions(+), 4 deletions(-) create mode 100644 tests/unit/steps/tasks/test_urial.py diff --git a/tests/unit/steps/tasks/test_text_generation.py b/tests/unit/steps/tasks/test_text_generation.py index dd43530cd9..2ed399b237 100644 --- a/tests/unit/steps/tasks/test_text_generation.py +++ b/tests/unit/steps/tasks/test_text_generation.py @@ -21,11 +21,8 @@ class TestTextGeneration: def test_format_input(self) -> None: - pipeline = Pipeline(name="unit-test-pipeline") llm = DummyLLM() - task = TextGeneration( - name="task", llm=llm, pipeline=pipeline, use_system_prompt=False - ) + task = TextGeneration(name="task", llm=llm, use_system_prompt=False) assert task.format_input({"instruction": "test", "system_prompt": "test"}) == [ {"role": "user", "content": "test"} diff --git a/tests/unit/steps/tasks/test_urial.py b/tests/unit/steps/tasks/test_urial.py new file mode 100644 index 0000000000..2075d98e6e --- /dev/null +++ b/tests/unit/steps/tasks/test_urial.py @@ -0,0 +1,72 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +from distilabel.steps.tasks.urial import URIAL + +from tests.unit.conftest import DummyLLM + + +class TestURIAL: + def test_format_input(self) -> None: + task = URIAL(llm=DummyLLM()) + task.load() + assert task.format_input({"instruction": "test"}) == [ + { + "role": "user", + "content": '# Instruction\n\nBelow is a list of conversations between a human and an AI assistant (you). \nUsers place their queries under "# User:", and your responses are under "# Assistant:".\nYou are a helpful, respectful, and honest assistant.\nYou should always answer as helpfully as possible while ensuring safety.\nYour answers should be well-structured and provide detailed information. They should also have an engaging tone.\nYour responses must not contain any fake, harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, even if it may be helpful.\nYour response must be socially responsible, and thus you can refuse to answer some controversial topics.\n\n\n# User:\n\ntest\n\n# Assistant:', + } + ] + + def test_format_input_with_conversation(self) -> None: + task = URIAL(llm=DummyLLM()) + task.load() + assert task.format_input( + { + "conversation": [ + {"role": "user", "content": "test"}, + {"role": "assistant", "content": "test"}, + {"role": "user", "content": "test"}, + ] + } + ) == [ + { + "role": "user", + "content": '# Instruction\n\nBelow is a list of conversations between a human and an AI assistant (you). \nUsers place their queries under "# User:", and your responses are under "# Assistant:".\nYou are a helpful, respectful, and honest assistant.\nYou should always answer as helpfully as possible while ensuring safety.\nYour answers should be well-structured and provide detailed information. 
They should also have an engaging tone.\nYour responses must not contain any fake, harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, even if it may be helpful.\nYour response must be socially responsible, and thus you can refuse to answer some controversial topics.\n\n\n# User:\n\ntest\n\n# Assistant:\n\ntest\n\n# User:\n\ntest\n\n# Assistant:', + } + ] + + def test_format_input_raise_valueerror(self) -> None: + task = URIAL(llm=DummyLLM()) + task.load() + + with pytest.raises(ValueError, match="The last message must be from the user."): + assert task.format_input( + { + "conversation": [ + {"role": "user", "content": "test"}, + {"role": "assistant", "content": "test"}, + ] + } + ) + + def test_format_output(self) -> None: + task = URIAL(llm=DummyLLM()) + task.load() + + assert task.format_output( + output=" \n\noutput\n\n# User:", input={"instruction": "test"} + ) == { + "generation": "output", + }