[v2] Refactor retrieval #1750

Open. Wants to merge 10 commits into base: v2.0.0.
242 changes: 118 additions & 124 deletions mteb/abstasks/AbsTaskRetrieval.py
@@ -3,6 +3,7 @@
import json
import logging
import os
from collections import defaultdict
from pathlib import Path
from time import time
from typing import Any
@@ -15,7 +16,7 @@
from ..evaluation.evaluators.utils import make_score_dict
from ..load_results.task_results import ScoresDict
from .AbsTask import AbsTask
from .dataloaders import HFDataLoader
from .dataloaders import RetrievalDataLoader
from .TaskMetadata import DescriptiveStatistics

logger = logging.getLogger(__name__)
@@ -126,108 +127,64 @@ class AbsTaskRetrieval(AbsTask):

ignore_identical_ids: bool = False
abstask_prompt = "Retrieve text based on user query."

instructions = None
top_ranked = None

def load_data(self, **kwargs):
if self.data_loaded:
return
self.corpus, self.queries, self.relevant_docs = {}, {}, {}
self.instructions, self.top_ranked = None, None
dataset_path = self.metadata.dataset["path"]
hf_repo_qrels = (
dataset_path + "-qrels" if "clarin-knext" in dataset_path else None
)
if not self.is_multilingual:
for split in kwargs.get("eval_splits", self.metadata.eval_splits):
corpus, queries, qrels, instructions, top_ranked = HFDataLoader(
hf_repo=dataset_path,
hf_repo_qrels=hf_repo_qrels,
streaming=False,
keep_in_memory=False,
trust_remote_code=self.metadata.dataset.get(
"trust_remote_code", False
),
).load(split=split)
# Conversion from DataSet
queries = {query["id"]: query["text"] for query in queries}
corpus = {
doc["id"]: doc.get("title", "") + " " + doc["text"]
for doc in corpus
}
self.corpus[split], self.queries[split], self.relevant_docs[split] = (
corpus,
queries,
qrels,
)

# optional args
if instructions:
self.instructions = {
split: {
inst["query-id"]: inst["instruction"]
for inst in instructions
}
}
if top_ranked:
self.top_ranked = {
split: {tr["query-id"]: tr["corpus-ids"] for tr in top_ranked}
}
else:
if not isinstance(self.metadata.eval_langs, dict):
raise ValueError("eval_langs must be a dict for multilingual tasks")
for lang in self.metadata.eval_langs:
self.corpus[lang], self.queries[lang], self.relevant_docs[lang] = (
{},
{},
{},
)
for split in kwargs.get("eval_splits", self.metadata.eval_splits):
corpus, queries, qrels, instructions, top_ranked = HFDataLoader(
hf_repo=dataset_path,
hf_repo_qrels=hf_repo_qrels,
streaming=False,
keep_in_memory=False,
trust_remote_code=self.metadata.dataset.get(
"trust_remote_code", False
),
).load(split=split, config=lang)
# Conversion from DataSet
queries = {query["id"]: query["text"] for query in queries}
corpus = {
doc["id"]: doc.get("title", "") + " " + doc["text"]
for doc in corpus
}
(
self.corpus[lang][split],
self.queries[lang][split],
self.relevant_docs[lang][split],
) = (
corpus,
queries,
qrels,
)
self.corpus = defaultdict(dict)
self.queries = defaultdict(dict)
self.relevant_docs = defaultdict(dict)
self.instructions = None
self.top_ranked = None

# optional args
if instructions:
if self.instructions is None:
self.instructions = {}
self.instructions[lang] = {
split: {
inst["query-id"]: inst["instruction"]
for inst in instructions
}
}
if top_ranked:
if self.top_ranked is None:
self.top_ranked = {}
self.top_ranked[lang] = {
split: {
tr["query-id"]: tr["corpus-ids"] for tr in top_ranked
}
}
dataset_path = self.metadata.dataset["path"]
eval_splits = kwargs.get("eval_splits", self.metadata.eval_splits)
trust_remote_code = self.metadata.dataset.get("trust_remote_code", False)

def process_data(split: str, lang: str | None = None):
"""Helper function to load and process data for a given split and language"""
corpus, queries, qrels, instructions, top_ranked = RetrievalDataLoader(
hf_repo=dataset_path,
trust_remote_code=trust_remote_code,
split=split,
config=lang,
).load()

if lang:
self.corpus[lang][split] = corpus
self.queries[lang][split] = queries
self.relevant_docs[lang][split] = qrels
else:
self.corpus[split] = corpus
self.queries[split] = queries
self.relevant_docs[split] = qrels

if instructions:
if self.instructions is None:
self.instructions = defaultdict(dict)
if lang:
self.instructions[lang][split] = instructions
else:
self.instructions[split] = instructions

if top_ranked:
if self.top_ranked is None:
self.top_ranked = defaultdict(dict)
if lang:
self.top_ranked[lang][split] = top_ranked
else:
self.top_ranked[split] = top_ranked

if self.is_multilingual:
for lang in self.metadata.eval_langs:
for split in eval_splits:
process_data(split, lang)
else:
for split in eval_splits:
process_data(split)
self.data_loaded = True
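
To make the new layout concrete, here is a minimal sketch of the nested structures the refactored load_data builds. The split, language, and id values below are illustrative assumptions, not values taken from this diff.

# For a monolingual task, self.corpus / self.queries / self.relevant_docs
# are keyed by split (illustrative ids and values):
corpus = {"test": {"d1": {"title": "A title", "text": "A passage."}}}
queries = {"test": {"q1": "an example query"}}
relevant_docs = {"test": {"q1": {"d1": 1}}}

# For a multilingual task, a language level is nested on top, which is why
# the containers are initialized as defaultdict(dict):
corpus_multilingual = {"eng": {"test": {"d1": {"title": "A title", "text": "A passage."}}}}

# instructions / top_ranked stay None unless the dataset provides them:
instructions = {"test": {"q1": "Prefer peer-reviewed sources."}}
top_ranked = {"test": {"q1": ["d1", "d7"]}}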

@@ -255,54 +212,91 @@ def evaluate(
logger.info(f"Subset: {hf_subset}")

if hf_subset == "default" and "default" not in self.corpus:
corpus, queries, relevant_docs = (
self.corpus[split],
self.queries[split],
self.relevant_docs[split],
)
if self.top_ranked is not None:
kwargs["top_ranked"] = self.top_ranked[split]
if self.instructions is not None:
kwargs["instructions"] = self.instructions[split]
corpus = self.corpus[split]
queries = self.queries[split]
relevant_docs = self.relevant_docs[split]
top_ranked = self.top_ranked[split] if self.top_ranked else None
instructions = self.instructions[split] if self.instructions else None
else:
corpus, queries, relevant_docs = (
self.corpus[hf_subset][split],
self.queries[hf_subset][split],
self.relevant_docs[hf_subset][split],
corpus = self.corpus[hf_subset][split]
queries = self.queries[hf_subset][split]
relevant_docs = self.relevant_docs[hf_subset][split]
top_ranked = (
self.top_ranked[hf_subset][split] if self.top_ranked else None
)
instructions = (
self.instructions[hf_subset][split] if self.instructions else None
)
if self.top_ranked is not None:
kwargs["top_ranked"] = self.top_ranked[hf_subset][split]
if self.instructions is not None:
kwargs["instructions"] = self.instructions[hf_subset][split]

scores[hf_subset] = self._evaluate_subset(
retriever, corpus, queries, relevant_docs, hf_subset, **kwargs
retriever,
corpus,
queries,
relevant_docs,
hf_subset,
top_ranked,
instructions,
**kwargs,
)
return scores
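
For illustration, the optional per-split inputs that evaluate now resolves explicitly and hands to _evaluate_subset (instead of injecting them through **kwargs) have roughly these shapes; the query and document ids are assumptions:

# query-id -> ordered candidate corpus-ids to rerank
top_ranked = {"q1": ["d3", "d1", "d7"]}
# query-id -> instruction string
instructions = {"q1": "Retrieve passages that answer the question."}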

def _evaluate_subset(
self, retriever, corpus, queries, relevant_docs, hf_subset: str, **kwargs
self,
retriever: RetrievalEvaluator,
corpus: dict[str, dict[str, str]],
queries: dict[str, str],
relevant_docs: dict[str, dict[str, int]],
hf_subset: str,
top_ranked: dict[str, list[str]] | None = None,
instructions: dict[str, str] | None = None,
save_predictions: bool = False,
export_errors: bool = False,
save_qrels: bool = False,
output_folder: str = "results",
results: dict[str, dict[str, float]] | None = None,
top_k: int | None = None,
**kwargs,
) -> ScoresDict:
if "results" in kwargs:
# reranking has already been done
results = kwargs["results"]
else:
"""Evaluate the retrieval task for a given subset of the dataset.

Args:
retriever: Evaluation object
corpus: Corpus to evaluate on
queries: Queries to evaluate on
relevant_docs: Relevant documents for the queries
hf_subset: Subset of the dataset
top_ranked: Top ranked documents (used for reranking)
instructions: Instructions for the queries (used for InstructRetrieval/Reranking)
save_predictions: Whether to save the predictions
export_errors: Whether to export errors
save_qrels: Whether to save the qrels
output_folder: Folder to save the results
results: Results from retrieval from previous run
top_k: Top k documents to consider
**kwargs: kwargs

Returns:
ScoresDict: Evaluation scores
"""
if not results:
# perform the retrieval here
start_time = time()
results = retriever(corpus, queries, **kwargs)
results = retriever(
corpus,
queries,
instructions=instructions,
top_ranked=top_ranked,
**kwargs,
)
end_time = time()
logger.info(f"Time taken to retrieve: {end_time - start_time:.2f} seconds")

save_predictions = kwargs.get("save_predictions", False)
export_errors = kwargs.get("export_errors", False)
save_qrels = kwargs.get("save_qrels", False)
if save_predictions or export_errors or save_qrels:
output_folder = Path(kwargs.get("output_folder", "results"))
output_folder = Path(output_folder)
if not os.path.isdir(output_folder):
os.makedirs(output_folder)

if save_predictions:
top_k = kwargs.get("top_k", None)
if top_k is not None:
for qid in list(results.keys()):
doc_ids = set(
Expand Down Expand Up @@ -345,7 +339,7 @@ def _evaluate_subset(
if export_errors:
errors = {}

top_k = kwargs.get("top_k", 1)
top_k = top_k or 1
if not save_predictions and top_k == 1:
for qid in results.keys():
doc_scores = results[qid]
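
For context, a minimal end-to-end usage sketch under the refactored task API. The task name, model name, and surrounding mteb entry points are assumptions and may differ on the v2 branch:

import mteb

# Load a retrieval task; load_data() populates the nested dicts shown above.
task = mteb.get_task("SciFact")  # assumed task name
task.load_data(eval_splits=["test"])

# Run the evaluation; _evaluate_subset receives top_ranked / instructions
# explicitly when the dataset provides them.
model = mteb.get_model("sentence-transformers/all-MiniLM-L6-v2")  # assumed model
results = mteb.MTEB(tasks=[task]).run(model, output_folder="results")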