Skip to content

Commit

Permalink
Add hard-negative flag to include similar challenging negatives on triplets (#856)

Browse files Browse the repository at this point in the history

* Add hard-negative flag to include similar challenging negatives on triplets

* Update src/distilabel/steps/tasks/sentence_transformers.py

Co-authored-by: Gabriel Martín Blázquez <[email protected]>

---------

Co-authored-by: Gabriel Martín Blázquez <[email protected]>
  • Loading branch information
plaguss and gabrielmbmb authored Aug 6, 2024
1 parent 092c364 commit ff3f484
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 34 deletions.
88 changes: 68 additions & 20 deletions src/distilabel/steps/tasks/sentence_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,22 @@
" section: `## Positive`."
)

# Instruction snippets selected by the ``hard_negative`` flag: a plain
# "negative" must simply be unrelated to the anchor, while a "hard-negative"
# mimics the anchor's wording and structure so it is harder for a model to
# distinguish from the positive sentence.
NEGATIVE_STYLE: Dict[str, str] = {
    "negative": (
        "can use similar words but must not be related to the anchor sentence"
    ),
    "hard-negative": (
        "is a 'hard negative' that meets the following criteria:\n"
        "- Uses similar keywords or phrases as the anchor sentence\n"
        "- Has a similar grammatical structure or syntax\n"
        "- Is not related to the anchor sentence, but could be mistaken for it\n"
        "Try to create a negative sentence that would be challenging for a "
        "model to distinguish from the positive sentence"
    ),
}

# System prompt used when ``triplet=True``. Placeholders:
#   {context}         -> CONTEXT_INTRO when a context is given, else ""
#   {action_sentence} -> action-specific instruction for the positive sentence
#   {negative_style}  -> one of the NEGATIVE_STYLE snippets
# NOTE(fix): the diff view had left the pre-change lines (" sentence can use
# similar words … `## Negative`.") concatenated together with the new
# " sentence {negative_style}. …" line, which would have produced a malformed
# prompt with two contradictory endings; only the post-change text is kept.
POSITIVE_NEGATIVE_SYSTEM_PROMPT: str = (
    "Your task is to generate a positive and a negative sentence given an anchor sentence.{context}"
    " The positive sentence has to {action_sentence} the anchor sentence, while the negative"
    " sentence {negative_style}. You must output only two new sections: `## Positive` and `## Negative`."
)

CONTEXT_INTRO: Final[str] = " Take into account the context given."
Expand All @@ -63,23 +74,28 @@ class GenerateSentencePair(Task):
`GenerateSentencePair` is a pre-defined task that given an anchor sentence generates
a positive sentence related to the anchor and optionally a negative sentence unrelated
to the anchor. Optionally, you can give a context to guide the LLM towards more specific
behavior. This task is useful to generate training datasets for training embeddings
models.
to the anchor or similar to it. Optionally, you can give a context to guide the LLM
towards more specific behavior. This task is useful to generate training datasets for
training embeddings models.
Attributes:
triplet: a flag to indicate if the task should generate a triplet of sentences
(anchor, positive, negative). Defaults to `False`.
action: the action to perform to generate the positive sentence.
context: the context to use for the generation. Can be helpful to guide the LLM
towards more specific context. Not used by default.
hard_negative: A flag to indicate if the negative should be a hard-negative or not.
Hard negatives make it hard for the model to distinguish against the positive,
with a higher degree of semantic similarity.
Input columns:
- anchor (`str`): The anchor sentence to generate the positive and negative sentences.
Output columns:
- positive (`str`): The positive sentence related to the `anchor`.
- negative (`str`): The negative sentence unrelated to the `anchor` if `triplet=True`.
- negative (`str`): The negative sentence unrelated to the `anchor` if `triplet=True`,
or more similar to the positive to make it more challenging for a model to distinguish
in case `hard_negative=True`.
- model_name (`str`): The name of the model that was used to generate the sentences.
Categories:
Expand All @@ -97,8 +113,8 @@ class GenerateSentencePair(Task):
triplet=True, # `False` to generate only positive
action="paraphrase",
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
),
input_batch_size=10,
)
Expand All @@ -118,8 +134,8 @@ class GenerateSentencePair(Task):
triplet=True, # `False` to generate only positive
action="semantically-similar",
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
),
input_batch_size=10,
)
Expand All @@ -139,8 +155,8 @@ class GenerateSentencePair(Task):
triplet=True, # `False` to generate only positive
action="query",
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
),
input_batch_size=10,
)
Expand All @@ -160,8 +176,8 @@ class GenerateSentencePair(Task):
triplet=True, # `False` to generate only positive
action="answer",
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
),
input_batch_size=10,
)
Expand All @@ -182,8 +198,31 @@ class GenerateSentencePair(Task):
action="query",
context="Argilla is an open-source data curation platform for LLMs.",
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
),
input_batch_size=10,
)
generate_sentence_pair.load()
result = generate_sentence_pair.process([{"anchor": "I want to generate queries for my LLM."}])
```
Generating Hard-negatives (**applies to every action**):
```python
from distilabel.steps.tasks import GenerateSentencePair
from distilabel.llms import InferenceEndpointsLLM
generate_sentence_pair = GenerateSentencePair(
triplet=True, # `False` to generate only positive
action="query",
context="Argilla is an open-source data curation platform for LLMs.",
hard_negative=True,
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
),
input_batch_size=10,
)
Expand All @@ -192,10 +231,12 @@ class GenerateSentencePair(Task):
result = generate_sentence_pair.process([{"anchor": "I want to generate queries for my LLM."}])
```
"""

# When True, generate an (anchor, positive, negative) triplet; when False,
# only the positive sentence is generated.
triplet: bool = False
# Action guiding how the positive relates to the anchor (e.g. "paraphrase",
# "query", "answer" — see the usage examples in the class docstring).
action: GenerationAction
# When True (only meaningful with triplet=True), request a "hard" negative
# that reuses the anchor's keywords/structure so it is harder for a model
# to distinguish from the positive.
hard_negative: bool = False
# Optional context added to the system prompt; the empty string disables it.
context: str = ""

def load(self) -> None:
Expand Down Expand Up @@ -229,12 +270,19 @@ def format_input(self, input: Dict[str, Any]) -> "ChatType":
A list of dictionaries containing the system and user interactions.
"""
action_sentence = GENERATION_ACTION_SENTENCES[self.action]

format_system_prompt = {
"action_sentence": action_sentence,
"context": CONTEXT_INTRO if self.context else "",
}
if self.triplet:
format_system_prompt["negative_style"] = NEGATIVE_STYLE[
"hard-negative" if self.hard_negative else "negative"
]

system_prompt = (
POSITIVE_NEGATIVE_SYSTEM_PROMPT if self.triplet else POSITIVE_SYSTEM_PROMPT
).format(
action_sentence=action_sentence,
context=CONTEXT_INTRO if self.context else "",
)
).format(**format_system_prompt)

return [
{"role": "system", "content": system_prompt},
Expand Down
Loading

0 comments on commit ff3f484

Please sign in to comment.