[Question] Hard negative mining #2697
Comments
Hello! It does not currently, although this would be a very valuable addition! I'd be very happy to receive a pull request for this. Please let me know if that sounds feasible!
Hi @tomaarsen, thanks for your reply! I could probably submit a PR for this. Can you give me some initial feedback on the following before I submit a draft?
import torch
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer, util, CrossEncoder
from datasets import Dataset
from tqdm.auto import tqdm
# Synthetic data for testing
def create_sample_pairs_dataset():
    data = {
        'anchor': [
            "What are the health benefits of regular exercise?",
            "How to improve your time management skills?",
            "What is the capital of France?",
            "How does photosynthesis work in plants?",
            "What are the symptoms of a common cold?",
            "How to cook a perfect steak?",
            "What is the importance of cybersecurity?",
            "How to start investing in stocks?",
            "What are the best practices for remote work?",
            "How does blockchain technology work?",
            "What are the stages of the water cycle?",
            "How to learn a new language quickly?",
            "What are the effects of climate change?",
            "How to maintain a healthy work-life balance?",
            "What are the benefits of meditation?",
            "How to build a successful startup?"
        ],
        'positive': [
            "Regular exercise improves cardiovascular health, strengthens muscles, and enhances mental well-being.",
            "Improving time management skills involves setting priorities, using tools like calendars and planners, and avoiding procrastination.",
            "The capital of France is Paris, a major European city known for its art, fashion, and culture.",
            "Photosynthesis is the process by which plants convert sunlight into chemical energy, producing oxygen as a byproduct.",
            "Common cold symptoms include a runny nose, sore throat, coughing, sneezing, and congestion.",
            "To cook a perfect steak, season it well, use a hot pan, and let it rest before serving.",
            "Cybersecurity is crucial for protecting sensitive information from cyber threats and maintaining privacy.",
            "Starting to invest in stocks requires understanding the market, researching companies, and considering risks.",
            "Best practices for remote work include setting up a dedicated workspace, maintaining regular hours, and communicating effectively.",
            "Blockchain technology is a decentralized ledger system that ensures secure and transparent transactions.",
            "The stages of the water cycle include evaporation, condensation, precipitation, and collection.",
            "Learning a new language quickly involves consistent practice, immersion, and using language learning apps.",
            "Climate change effects include rising temperatures, melting ice caps, and increased frequency of extreme weather events.",
            "Maintaining a healthy work-life balance requires setting boundaries, prioritizing self-care, and managing time effectively.",
            "Meditation benefits include reduced stress, improved concentration, and enhanced emotional health.",
            "Building a successful startup involves identifying a market need, creating a solid business plan, and securing funding."
        ]
    }
    df = pd.DataFrame(data)
    dataset = Dataset.from_pandas(df)
    return dataset
def add_hard_negatives(dataset, embedding_model_name, cross_encoder_name, range_min=1, threshold=0.5, batch_size=8, use_gpu=True, negative_number=3):
    """
    Add hard negatives to a dataset of (anchor, positive) pairs to create (anchor, positive, negative) triplets.

    Args:
        dataset (Dataset): The dataset containing (anchor, positive) pairs.
        embedding_model_name (str): Name of the embedding model to use.
        cross_encoder_name (str): Name of the cross encoder model to use.
        range_min (int): Number of top-ranked candidates to skip after filtering (e.g., if 2, the 2 closest remaining candidates are not used).
        threshold (float): CrossEncoder score at or above which a candidate is treated as a false negative and dropped.
        batch_size (int): Batch size for processing.
        use_gpu (bool): Whether to use the GPU when one is available.
        negative_number (int): Maximum number of negatives to sample per anchor.

    Returns:
        Dataset: A dataset containing (anchor, positive, negative) triplets.
    """
    device = "cuda" if torch.cuda.is_available() and use_gpu else "cpu"
    model = SentenceTransformer(embedding_model_name, device=device)
    cross_encoder = CrossEncoder(cross_encoder_name, device=device)
    # Retrieve enough candidates to cover the true positive, the skipped ranks, and the requested negatives
    k = negative_number + range_min + 1
    # Use only the positives as the corpus to sample negatives from
    anchors = list(dataset['anchor'])
    positives = list(dataset['positive'])
    sentences = positives
    embeddings = model.encode(sentences, convert_to_tensor=True, device=device)
    # Find the top-k matching corpus entries for every anchor, in batches
    triplets_data = []
    for start_idx in tqdm(range(0, len(anchors), batch_size), desc="Batches"):
        end_idx = min(start_idx + batch_size, len(anchors))
        batch_embeddings = model.encode(anchors[start_idx:end_idx], convert_to_tensor=True, device=device)
        for idx, query_embedding in enumerate(batch_embeddings):
            hits = util.semantic_search(query_embedding, embeddings, top_k=k)[0]
            # Filter out the true positive
            true_positive_idx = start_idx + idx
            anchor = anchors[true_positive_idx]
            positive = positives[true_positive_idx]
            hits = [hit for hit in hits if sentences[hit['corpus_id']] != positive]
            if not hits:
                continue  # Skip if no potential negatives are found
            # Use the CrossEncoder to filter out false negatives
            cross_encoder_scores = cross_encoder.predict([[anchor, sentences[hit['corpus_id']]] for hit in hits])
            # Keep only candidates that the CrossEncoder scores below the threshold
            filtered_hits = [hit for hit, score in zip(hits, cross_encoder_scores) if score < threshold]
            # Skip the closest remaining candidates, then sample negatives from the rest
            filtered_hits = filtered_hits[range_min:]
            if len(filtered_hits) > negative_number:
                chosen = np.random.choice(len(filtered_hits), negative_number, replace=False)
                filtered_hits = [filtered_hits[i] for i in chosen]
            if len(filtered_hits) == 0:
                continue  # Skip if no hard negatives are found
            # Create triplets (anchor, positive, negative)
            for hit in filtered_hits:
                negative = sentences[hit['corpus_id']]
                triplets_data.append({
                    'anchor': anchor,
                    'positive': positive,
                    'negative': negative
                })
    if len(triplets_data) == 0:
        raise ValueError("No triplets were generated. Please check the parameters and dataset.")
    triplets_dataset = Dataset.from_pandas(pd.DataFrame(triplets_data))
    return triplets_dataset
# Example usage
if __name__ == "__main__":
    embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
    cross_encoder_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    threshold = 0.5  # Default threshold, can be adjusted
    batch_size = 8  # Batch size for processing
    range_min = 2  # Number of closest candidates to skip
    use_gpu = True  # Use GPU by default
    negative_number = 3  # Number of negatives to sample
    # Create sample pairs dataset
    sample_pairs_dataset = create_sample_pairs_dataset()
    # Generate hard negatives
    hard_negative_dataset = add_hard_negatives(
        dataset=sample_pairs_dataset,
        embedding_model_name=embedding_model_name,
        cross_encoder_name=cross_encoder_name,
        range_min=range_min,
        threshold=threshold,
        batch_size=batch_size,
        use_gpu=use_gpu,
        negative_number=negative_number
    )
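For anyone who wants to sanity-check the selection step without downloading any models, here is a minimal NumPy-only sketch of the same idea: rank candidates by similarity, drop the known positive, drop anything the second-stage scorer still rates as relevant, skip the closest remaining candidates, then sample. The function name `select_hard_negatives` and the toy score arrays are hypothetical and only for illustration; this is not part of sentence-transformers, and unlike the draft above it ranks the whole corpus instead of only the top k.

```python
import numpy as np


def select_hard_negatives(sim_row, scorer_row, positive_idx,
                          range_min=1, threshold=0.5, num_negatives=3, rng=None):
    """Return corpus indices to use as hard negatives for one anchor.

    sim_row: first-stage similarities between the anchor and each corpus entry.
    scorer_row: second-stage (e.g. cross-encoder) scores for the same entries.
    """
    if rng is None:
        rng = np.random.default_rng()
    scorer_row = np.asarray(scorer_row)
    # Rank corpus entries from most to least similar to the anchor.
    ranked = np.argsort(-np.asarray(sim_row))
    # Never use the known positive as a negative.
    ranked = ranked[ranked != positive_idx]
    # Drop candidates the second-stage scorer still considers relevant.
    ranked = ranked[scorer_row[ranked] < threshold]
    # Skip the closest remaining candidates.
    ranked = ranked[range_min:]
    # Sample down if more candidates remain than requested.
    if len(ranked) > num_negatives:
        ranked = rng.choice(ranked, size=num_negatives, replace=False)
    return [int(i) for i in ranked]


# Toy scores (made up): entry 0 is the positive, entries 1 and 4 look like false negatives.
sims = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4]
scores = [0.9, 0.9, 0.1, 0.2, 0.7, 0.3]
print(select_hard_negatives(sims, scores, positive_idx=0))  # [3, 5]
```

Entries 2, 3 and 5 survive the threshold here, and `range_min=1` then skips entry 2, which shows how the rank skip and the threshold interact when applied in this order.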
@austinmw This is a great starting point! I actually have quite a lot of ideas for expanding this helper function. Would you be okay with me taking over development based on this and opening a PR myself?
I was looking into this to try to reproduce something similar for the TREC RAG version of MSMARCO. This functionality would be very much appreciated :D
@tomaarsen absolutely, thanks!
Hi, does sentence-transformers happen to have any utility methods to generate an expanded dataset with hard negatives from an input dataset and model?