diff --git a/src/distilabel/steps/tasks/sentence_transformers.py b/src/distilabel/steps/tasks/sentence_transformers.py
index b1ad50f5e1..666c89243f 100644
--- a/src/distilabel/steps/tasks/sentence_transformers.py
+++ b/src/distilabel/steps/tasks/sentence_transformers.py
@@ -48,11 +48,22 @@
     " section: `## Positive`."
 )
 
+NEGATIVE_STYLE: Dict[str, str] = {
+    "negative": "can use similar words but must not be related to the anchor sentence",
+    "hard-negative": (
+        "is a 'hard negative' that meets the following criteria:\n"
+        "- Uses similar keywords or phrases as the anchor sentence\n"
+        "- Has a similar grammatical structure or syntax\n"
+        "- Is not related to the anchor sentence, but could be mistaken for it\n"
+        "Try to create a negative sentence that would be challenging for a model to distinguish "
+        "from the positive sentence"
+    ),
+}
+
 POSITIVE_NEGATIVE_SYSTEM_PROMPT: str = (
     "Your task is to generate a positive and a negative sentence given an anchor sentence.{context}"
     " The positive sentence has to {action_sentence} the anchor sentence, while the negative"
-    " sentence can use similar words but must not be related to the anchor sentence. You"
-    " must output only two new sections: `## Positive` and `## Negative`."
+    " sentence {negative_style}. You must output only two new sections: `## Positive` and `## Negative`."
 )
 
 CONTEXT_INTRO: Final[str] = " Take into account the context given."
@@ -63,9 +74,9 @@ class GenerateSentencePair(Task):
 
     `GenerateSentencePair` is a pre-defined task that given an anchor sentence generates
     a positive sentence related to the anchor and optionally a negative sentence unrelated
-    to the anchor. Optionally, you can give a context to guide the LLM towards more specific
-    behavior. This task is useful to generate training datasets for training embeddings
-    models.
+    to the anchor or similar to it. Optionally, you can give a context to guide the LLM
+    towards more specific behavior. This task is useful for generating training datasets
+    for embedding models.
 
     Attributes:
         triplet: a flag to indicate if the task should generate a triplet of sentences
@@ -73,13 +84,18 @@ class GenerateSentencePair(Task):
        action: the action to perform to generate the positive sentence.
        context: the context to use for the generation. Can be helpful to guide the LLM
            towards more specific context. Not used by default.
+        hard_negative: a flag to indicate whether the negative should be a hard-negative or not.
+            Hard negatives are more difficult for the model to distinguish from the positive,
+            as they have a higher degree of semantic similarity to the anchor.
 
     Input columns:
         - anchor (`str`): The anchor sentence to generate the positive and negative sentences.
 
     Output columns:
         - positive (`str`): The positive sentence related to the `anchor`.
-        - negative (`str`): The negative sentence unrelated to the `anchor` if `triplet=True`.
+        - negative (`str`): The negative sentence unrelated to the `anchor` if `triplet=True`,
+            or, if `hard_negative=True`, a negative more similar to the positive, making it
+            more challenging for a model to distinguish between them.
         - model_name (`str`): The name of the model that was used to generate the sentences.
 
    Categories:
@@ -97,8 +113,8 @@ class GenerateSentencePair(Task):
             triplet=True,  # `False` to generate only positive
             action="paraphrase",
             llm=InferenceEndpointsLLM(
-                model_id="meta-llama/Meta-Llama-3-70B-Instruct",
-                tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
+                model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+                tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
             ),
             input_batch_size=10,
         )
@@ -118,8 +134,8 @@ class GenerateSentencePair(Task):
             triplet=True,  # `False` to generate only positive
             action="semantically-similar",
             llm=InferenceEndpointsLLM(
-                model_id="meta-llama/Meta-Llama-3-70B-Instruct",
-                tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
+                model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+                tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
             ),
             input_batch_size=10,
         )
@@ -139,8 +155,8 @@ class GenerateSentencePair(Task):
             triplet=True,  # `False` to generate only positive
             action="query",
             llm=InferenceEndpointsLLM(
-                model_id="meta-llama/Meta-Llama-3-70B-Instruct",
-                tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
+                model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+                tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
             ),
             input_batch_size=10,
         )
@@ -160,8 +176,8 @@ class GenerateSentencePair(Task):
             triplet=True,  # `False` to generate only positive
             action="answer",
             llm=InferenceEndpointsLLM(
-                model_id="meta-llama/Meta-Llama-3-70B-Instruct",
-                tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
+                model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+                tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
             ),
             input_batch_size=10,
         )
@@ -182,8 +198,31 @@ class GenerateSentencePair(Task):
             action="query",
             context="Argilla is an open-source data curation platform for LLMs.",
             llm=InferenceEndpointsLLM(
-                model_id="meta-llama/Meta-Llama-3-70B-Instruct",
-                tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
+                model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+                tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+            ),
+            input_batch_size=10,
+        )
+
+        generate_sentence_pair.load()
+
+        result = generate_sentence_pair.process([{"anchor": "I want to generate queries for my LLM."}])
+        ```
+
+    Generating Hard-negatives (**applies to every action**):
+
+    ```python
+    from distilabel.steps.tasks import GenerateSentencePair
+    from distilabel.llms import InferenceEndpointsLLM
+
+    generate_sentence_pair = GenerateSentencePair(
+        triplet=True,  # `False` to generate only positive
+        action="query",
+        context="Argilla is an open-source data curation platform for LLMs.",
+        hard_negative=True,
+        llm=InferenceEndpointsLLM(
+            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+            tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
             ),
             input_batch_size=10,
         )
@@ -192,10 +231,12 @@ class GenerateSentencePair(Task):
 
         result = generate_sentence_pair.process([{"anchor": "I want to generate queries for my LLM."}])
         ```
+
     """
 
     triplet: bool = False
     action: GenerationAction
+    hard_negative: bool = False
     context: str = ""
 
     def load(self) -> None:
@@ -229,12 +270,19 @@ def format_input(self, input: Dict[str, Any]) -> "ChatType":
             A list of dictionaries containing the system and user interactions.
""" action_sentence = GENERATION_ACTION_SENTENCES[self.action] + + format_system_prompt = { + "action_sentence": action_sentence, + "context": CONTEXT_INTRO if self.context else "", + } + if self.triplet: + format_system_prompt["negative_style"] = NEGATIVE_STYLE[ + "hard-negative" if self.hard_negative else "negative" + ] + system_prompt = ( POSITIVE_NEGATIVE_SYSTEM_PROMPT if self.triplet else POSITIVE_SYSTEM_PROMPT - ).format( - action_sentence=action_sentence, - context=CONTEXT_INTRO if self.context else "", - ) + ).format(**format_system_prompt) return [ {"role": "system", "content": system_prompt}, diff --git a/tests/unit/steps/tasks/test_sentence_transformers.py b/tests/unit/steps/tasks/test_sentence_transformers.py index 2f81240755..099de2ff09 100644 --- a/tests/unit/steps/tasks/test_sentence_transformers.py +++ b/tests/unit/steps/tasks/test_sentence_transformers.py @@ -17,6 +17,7 @@ import pytest from distilabel.steps.tasks.sentence_transformers import ( CONTEXT_INTRO, + NEGATIVE_STYLE, POSITIVE_NEGATIVE_SYSTEM_PROMPT, POSITIVE_SYSTEM_PROMPT, GenerateSentencePair, @@ -28,29 +29,57 @@ class TestGenerateSentencePair: @pytest.mark.parametrize( - "action,triplet,system_prompt", + "action,triplet,hard_negative,system_prompt", [ ( "paraphrase", True, + False, POSITIVE_NEGATIVE_SYSTEM_PROMPT.format( - action_sentence="paraphrase", context="" + action_sentence="paraphrase", + context="", + negative_style=NEGATIVE_STYLE["negative"], ), ), ( "paraphrase", + True, + True, + POSITIVE_NEGATIVE_SYSTEM_PROMPT.format( + action_sentence="paraphrase", + context="", + negative_style=NEGATIVE_STYLE["hard-negative"], + ), + ), + ( + "paraphrase", + False, False, POSITIVE_SYSTEM_PROMPT.format(action_sentence="paraphrase", context=""), ), ( "semantically-similar", True, + False, POSITIVE_NEGATIVE_SYSTEM_PROMPT.format( - action_sentence="be semantically similar to", context="" + action_sentence="be semantically similar to", + context="", + negative_style=NEGATIVE_STYLE["negative"], ), ), ( "semantically-similar", + True, + True, + POSITIVE_NEGATIVE_SYSTEM_PROMPT.format( + action_sentence="be semantically similar to", + context="", + negative_style=NEGATIVE_STYLE["hard-negative"], + ), + ), + ( + "semantically-similar", + False, False, POSITIVE_SYSTEM_PROMPT.format( action_sentence="be semantically similar to", context="" @@ -59,12 +88,26 @@ class TestGenerateSentencePair: ( "query", True, + False, POSITIVE_NEGATIVE_SYSTEM_PROMPT.format( - action_sentence="be a query for", context="" + action_sentence="be a query for", + context="", + negative_style=NEGATIVE_STYLE["negative"], ), ), ( "query", + True, + True, + POSITIVE_NEGATIVE_SYSTEM_PROMPT.format( + action_sentence="be a query for", + context="", + negative_style=NEGATIVE_STYLE["hard-negative"], + ), + ), + ( + "query", + False, False, POSITIVE_SYSTEM_PROMPT.format( action_sentence="be a query for", context="" @@ -73,12 +116,26 @@ class TestGenerateSentencePair: ( "answer", True, + False, POSITIVE_NEGATIVE_SYSTEM_PROMPT.format( - action_sentence="be an answer for", context="" + action_sentence="be an answer for", + context="", + negative_style=NEGATIVE_STYLE["negative"], ), ), ( "answer", + True, + True, + POSITIVE_NEGATIVE_SYSTEM_PROMPT.format( + action_sentence="be an answer for", + context="", + negative_style=NEGATIVE_STYLE["hard-negative"], + ), + ), + ( + "answer", + False, False, POSITIVE_SYSTEM_PROMPT.format( action_sentence="be an answer for", context="" @@ -87,9 +144,15 @@ class TestGenerateSentencePair: ], ) def test_format_input( - self, 
+        self,
+        action: GenerationAction,
+        triplet: bool,
+        hard_negative: bool,
+        system_prompt: str,
     ) -> None:
-        task = GenerateSentencePair(llm=DummyLLM(), action=action, triplet=triplet)
+        task = GenerateSentencePair(
+            llm=DummyLLM(), action=action, triplet=triplet, hard_negative=hard_negative
+        )
         task.load()
         content = "## Anchor\n\nThis is a unit test\n"
         assert task.format_input({"anchor": "This is a unit test"}) == [
@@ -98,18 +161,32 @@ def test_format_input(
         ]
 
     @pytest.mark.parametrize(
-        "action,triplet,system_prompt",
+        "action,triplet,hard_negative,system_prompt",
         [
             (
                 "paraphrase",
                 True,
+                False,
                 POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
-                    action_sentence="paraphrase", context=CONTEXT_INTRO
+                    action_sentence="paraphrase",
+                    context=CONTEXT_INTRO,
+                    negative_style=NEGATIVE_STYLE["negative"],
+                ),
+            ),
+            (
+                "paraphrase",
+                True,
+                True,
+                POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+                    action_sentence="paraphrase",
+                    context=CONTEXT_INTRO,
+                    negative_style=NEGATIVE_STYLE["hard-negative"],
                 ),
             ),
             (
                 "paraphrase",
                 False,
+                False,
                 POSITIVE_SYSTEM_PROMPT.format(
                     action_sentence="paraphrase", context=CONTEXT_INTRO
                 ),
@@ -117,13 +194,27 @@ def test_format_input(
             (
                 "semantically-similar",
                 True,
+                False,
                 POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
-                    action_sentence="be semantically similar to", context=CONTEXT_INTRO
+                    action_sentence="be semantically similar to",
+                    context=CONTEXT_INTRO,
+                    negative_style=NEGATIVE_STYLE["negative"],
+                ),
+            ),
+            (
+                "semantically-similar",
+                True,
+                True,
+                POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+                    action_sentence="be semantically similar to",
+                    context=CONTEXT_INTRO,
+                    negative_style=NEGATIVE_STYLE["hard-negative"],
                 ),
             ),
             (
                 "semantically-similar",
                 False,
+                False,
                 POSITIVE_SYSTEM_PROMPT.format(
                     action_sentence="be semantically similar to", context=CONTEXT_INTRO
                 ),
@@ -131,13 +222,27 @@ def test_format_input(
             (
                 "query",
                 True,
+                False,
                 POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
-                    action_sentence="be a query for", context=CONTEXT_INTRO
+                    action_sentence="be a query for",
+                    context=CONTEXT_INTRO,
+                    negative_style=NEGATIVE_STYLE["negative"],
+                ),
+            ),
+            (
+                "query",
+                True,
+                True,
+                POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+                    action_sentence="be a query for",
+                    context=CONTEXT_INTRO,
+                    negative_style=NEGATIVE_STYLE["hard-negative"],
                 ),
             ),
             (
                 "query",
                 False,
+                False,
                 POSITIVE_SYSTEM_PROMPT.format(
                     action_sentence="be a query for", context=CONTEXT_INTRO
                 ),
@@ -145,13 +250,27 @@ def test_format_input(
             (
                 "answer",
                 True,
+                False,
                 POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
-                    action_sentence="be an answer for", context=CONTEXT_INTRO
+                    action_sentence="be an answer for",
+                    context=CONTEXT_INTRO,
+                    negative_style=NEGATIVE_STYLE["negative"],
+                ),
+            ),
+            (
+                "answer",
+                True,
+                True,
+                POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+                    action_sentence="be an answer for",
+                    context=CONTEXT_INTRO,
+                    negative_style=NEGATIVE_STYLE["hard-negative"],
                 ),
             ),
             (
                 "answer",
                 False,
+                False,
                 POSITIVE_SYSTEM_PROMPT.format(
                     action_sentence="be an answer for", context=CONTEXT_INTRO
                 ),
@@ -159,7 +278,11 @@ def test_format_input_with_context(
         ],
     )
     def test_format_input_with_context(
-        self, action: GenerationAction, triplet: bool, system_prompt: str
+        self,
+        action: GenerationAction,
+        triplet: bool,
+        hard_negative: bool,
+        system_prompt: str,
     ) -> None:
         context = "This is your context."
         task = GenerateSentencePair(
@@ -167,10 +290,10 @@ def test_format_input_with_context(
             llm=DummyLLM(),
             action=action,
             triplet=triplet,
             context=context,
+            hard_negative=hard_negative,
         )
         task.load()
         content = f"## Context\n\n{context}\n\n## Anchor\n\nThis is a unit test\n"
-        # content = f"## Anchor\n\nThis is a unit test\n## Context\n\n{context}"
         assert task.format_input({"anchor": "This is a unit test"}) == [
             {"role": "system", "content": system_prompt},
             {"role": "user", "content": content},
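Reviewer note: a quick way to see what this patch changes in the rendered prompt. The sketch below reproduces the new `format_input` selection logic using only constants touched by this diff (`POSITIVE_NEGATIVE_SYSTEM_PROMPT`, `NEGATIVE_STYLE`); the `"be a query for"` action sentence is taken from the test parametrization. It assumes a checkout of this branch so the new constants are importable, and is an illustration only, not part of the patch.

```python
# Minimal sketch (assumes this branch is installed): prints the system prompt
# for both negative styles, mirroring how `format_input` fills the new
# `negative_style` placeholder.
from distilabel.steps.tasks.sentence_transformers import (
    NEGATIVE_STYLE,
    POSITIVE_NEGATIVE_SYSTEM_PROMPT,
)

for hard_negative in (False, True):
    # `hard_negative=True` selects the stricter "hard-negative" instructions,
    # otherwise the plain "negative" style is used.
    system_prompt = POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
        action_sentence="be a query for",  # GENERATION_ACTION_SENTENCES["query"]
        context="",  # empty unless a `context` attribute is set on the task
        negative_style=NEGATIVE_STYLE["hard-negative" if hard_negative else "negative"],
    )
    print(f"--- hard_negative={hard_negative} ---\n{system_prompt}\n")
```

Running it side by side makes the behavioral difference easy to eyeball: the default style only asks for an unrelated negative, while the hard-negative style asks for shared keywords and syntax, which is what makes the triplet harder for an embedding model to separate.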