Golden Retriever - A Lightning framework for prototyping retriever architectures

Riccorl/golden-retriever

🦮 Golden Retriever

WIP: distributed-compatible codebase

A distributed-compatible codebase is under development. Check the distributed branch for the latest updates.

How to use

Install the library from PyPI:

pip install goldenretriever-core

or from source:

git clone https://github.com/Riccorl/golden-retriever.git
cd golden-retriever
pip install -e .

Usage

How to run an experiment

Training

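Here is a simple example of how to train a DPR-like retriever on a DPR-format dataset. First download the data from DPR, then run the following code (the file names below refer to the WebQuestions split; point path at the NQ files to train on NQ instead):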

from goldenretriever.trainer import Trainer
from goldenretriever import GoldenRetriever
from goldenretriever.data.datasets import InBatchNegativesDataset

# create a retriever
retriever = GoldenRetriever(
    question_encoder="intfloat/e5-small-v2",
    passage_encoder="intfloat/e5-small-v2"
)

# create a dataset
train_dataset = InBatchNegativesDataset(
    name="webq_train",
    path="path/to/webq_train.json",
    tokenizer=retriever.question_tokenizer,
    question_batch_size=64,
    passage_batch_size=400,
    max_passage_length=64,
    shuffle=True,
)
val_dataset = InBatchNegativesDataset(
    name="webq_dev",
    path="path/to/webq_dev.json",
    tokenizer=retriever.question_tokenizer,
    question_batch_size=64,
    passage_batch_size=400,
    max_passage_length=64,
)

trainer = Trainer(
    retriever=retriever,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    max_steps=25_000,
    wandb_online_mode=True,
    wandb_project_name="golden-retriever-dpr",
    wandb_experiment_name="e5-small-webq",
    max_hard_negatives_to_mine=5,
)

# start training
trainer.train()
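Conceptually, a DPR-like retriever is trained with in-batch negatives: each question is scored against every passage in the batch, and every passage that is not its positive acts as a negative. The snippet below is a minimal, dependency-free sketch of that contrastive loss under our own assumptions (dot-product similarity, one positive per question, plain Python lists instead of tensors). It is not the loss code used inside `goldenretriever`, and `in_batch_negatives_loss` is a hypothetical name.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def dot(u: Vector, v: Vector) -> float:
    """Inner-product similarity between a question and a passage embedding."""
    return sum(a * b for a, b in zip(u, v))


def in_batch_negatives_loss(
    questions: Sequence[Vector],
    passages: Sequence[Vector],
    positives: Sequence[int],
) -> float:
    """Mean cross-entropy over the batch (hypothetical sketch).

    positives[i] is the index in `passages` of the positive passage for
    questions[i]; every other passage in the batch is treated as a negative.
    """
    if len(questions) != len(positives):
        raise ValueError("need exactly one positive index per question")
    total = 0.0
    for question, pos in zip(questions, positives):
        scores = [dot(question, passage) for passage in passages]
        # log-sum-exp, shifted by the max score for numerical stability
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        # -log softmax(scores)[pos]
        total += log_z - scores[pos]
    return total / len(questions)


if __name__ == "__main__":
    questions = [[1.0, 0.0], [0.0, 1.0]]
    # the third passage is not a positive for any question: an extra negative
    passages = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
    print(round(in_batch_negatives_loss(questions, passages, positives=[0, 1]), 4))  # 0.6803
```

Because positives are given as indices into the passage list, the passage batch can be larger than the question batch, as in the configuration above (64 questions, 400 passages); the extra passages simply enlarge the softmax denominator.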

Evaluation

from goldenretriever.trainer import Trainer
from goldenretriever import GoldenRetriever
from goldenretriever.data.datasets import InBatchNegativesDataset

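# load a trained retriever and its document index
retriever = GoldenRetriever(
    question_encoder="path/to/question/encoder",
    document_index="path/to/document/index",
    device="cuda",
    precision="16",
)

# the test set is loaded with the same dataset class used for training
test_dataset = InBatchNegativesDataset(
    name="test",
    path="path/to/test.json",
    tokenizer=retriever.question_tokenizer,
    question_batch_size=64,
    passage_batch_size=400,
    max_passage_length=64,
)

trainer = Trainer(
    retriever=retriever,
    test_dataset=test_dataset,
    log_to_wandb=False,
    top_k=[20, 100],
)

# run the evaluation on the test set
trainer.test()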
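The top_k=[20, 100] argument suggests that results are reported at several cut-offs. Assuming the metric is the DPR-style top-k accuracy (the fraction of questions with at least one positive passage among the first k retrieved), it can be computed as in this sketch. `top_k_accuracy` and its input layout are hypothetical and not part of the library's API.

```python
from typing import Dict, Sequence, Set


def top_k_accuracy(
    rankings: Sequence[Sequence[str]],
    gold: Sequence[Set[str]],
    ks: Sequence[int],
) -> Dict[int, float]:
    """Fraction of questions with at least one gold passage id in their top k.

    rankings[i] lists the retrieved passage ids for question i, best first;
    gold[i] holds the ids of that question's positive passages.
    """
    if len(rankings) != len(gold):
        raise ValueError("need one gold set per ranked list")
    results: Dict[int, float] = {}
    for k in ks:
        hits = sum(
            1
            for ranked, positives in zip(rankings, gold)
            if any(pid in positives for pid in ranked[:k])
        )
        results[k] = hits / len(rankings)
    return results


if __name__ == "__main__":
    rankings = [["p3", "p1", "p7"], ["p2", "p9", "p4"]]
    gold = [{"p1"}, {"p4"}]
    print(top_k_accuracy(rankings, gold, ks=[1, 2, 3]))  # {1: 0.0, 2: 0.5, 3: 1.0}
```

The metric is monotone in k, which is why reporting a small and a large cut-off (such as 20 and 100) shows both how precise the top of the ranking is and how much recall the retriever leaves for a downstream reader.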

Inference

from goldenretriever import GoldenRetriever

retriever = GoldenRetriever(
    question_encoder="path/to/question/encoder",
    passage_encoder="path/to/passage/encoder",
    document_index="path/to/document/index"
)

# retrieve documents
retriever.retrieve("What is the capital of France?", k=5)
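Conceptually, dense retrieval scores the question embedding against the passage embeddings stored in the document index and keeps the k highest-scoring passages. The following brute-force sketch assumes inner-product similarity and an in-memory list of `(doc_id, embedding)` pairs. The actual index in `goldenretriever` may be organized differently, and `retrieve_top_k` is a hypothetical helper, not `GoldenRetriever.retrieve`.

```python
import heapq
from typing import List, Sequence, Tuple


def retrieve_top_k(
    query: Sequence[float],
    index: Sequence[Tuple[str, Sequence[float]]],
    k: int = 5,
) -> List[Tuple[str, float]]:
    """Exhaustive inner-product search: score every (doc_id, embedding) pair
    against the query and return the k best, highest score first."""
    scored = (
        (doc_id, sum(q * d for q, d in zip(query, embedding)))
        for doc_id, embedding in index
    )
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])


if __name__ == "__main__":
    toy_index = [("paris", [0.9, 0.1]), ("berlin", [0.2, 0.8]), ("rome", [0.6, 0.4])]
    print(retrieve_top_k([1.0, 0.0], toy_index, k=2))  # [('paris', 0.9), ('rome', 0.6)]
```

`heapq.nlargest` avoids sorting the whole index, but it still scores every passage; large collections typically rely on a dedicated vector-search library instead.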

Data format

Input data

The retriever expects a JSON file in the same format as DPR:

[
  {
    "question": "....",
    "answers": ["...", "...", "..."],
    "positive_ctxs": [{
      "title": "...",
      "text": "...."
    }],
    "negative_ctxs": ["..."],
    "hard_negative_ctxs": ["..."]
  },
  ...
]
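Before training, it can help to check that a file matches the format above. This sketch parses the JSON text and, as an assumption of ours rather than a documented rule of the library, treats `question`, `answers` and `positive_ctxs` as required and the two negative lists as optional. `parse_dpr_samples` is a hypothetical helper.

```python
import json
from typing import Any, Dict, List

# Assumption: keys treated as required here, taken from the example above.
REQUIRED_KEYS = ("question", "answers", "positive_ctxs")
OPTIONAL_LIST_KEYS = ("negative_ctxs", "hard_negative_ctxs")


def parse_dpr_samples(raw: str) -> List[Dict[str, Any]]:
    """Parse a DPR-style JSON list and validate the keys of each sample."""
    samples = json.loads(raw)
    if not isinstance(samples, list):
        raise ValueError("expected a JSON list of samples")
    for i, sample in enumerate(samples):
        missing = [key for key in REQUIRED_KEYS if key not in sample]
        if missing:
            raise ValueError(f"sample {i} is missing keys: {missing}")
        if not sample["positive_ctxs"]:
            raise ValueError(f"sample {i} has no positive context")
        # fill absent negative lists with [] so downstream code can iterate them
        for key in OPTIONAL_LIST_KEYS:
            sample.setdefault(key, [])
    return samples
```

To check a file on disk, read it as text (for instance path/to/webq_train.json) and pass the resulting string to the function; the first malformed sample raises a `ValueError` naming its position.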

Index data

The documents to index can be stored either in a jsonl file or in a tsv file, similar to DPR:

  • jsonl: each line is a JSON object with the keys id, text and metadata
  • tsv: each line is a tab-separated string with the id and text columns, followed by any additional columns, which are stored in the metadata field

jsonl example:

{"id": "...", "text": "...", "metadata": ["{...}"]}
...
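A JSONL document file can be streamed one line at a time, which avoids loading the whole collection into memory as a single JSON value. A minimal reader, assuming every object carries the `id`, `text` and `metadata` keys listed above (`read_jsonl_documents` is a hypothetical name, not a library function):

```python
import json
from typing import Any, Dict, Iterable, Iterator

EXPECTED_KEYS = ("id", "text", "metadata")


def read_jsonl_documents(lines: Iterable[str]) -> Iterator[Dict[str, Any]]:
    """Yield one document per non-empty line of a JSONL file."""
    for line_no, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue  # skip blank lines, e.g. a trailing newline at end of file
        doc = json.loads(line)
        missing = [key for key in EXPECTED_KEYS if key not in doc]
        if missing:
            raise ValueError(f"line {line_no}: missing keys {missing}")
        yield doc
```

Because the function accepts any iterable of strings, an open file handle works directly, e.g. `with open("documents.jsonl", encoding="utf-8") as f: docs = list(read_jsonl_documents(f))`, where `documents.jsonl` is a placeholder path.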

tsv example:

id \t text \t any other column
...
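A TSV file can be parsed line by line. In this sketch, which is an assumption about the layout rather than the library's loader, the first two columns become `id` and `text`, every remaining column is appended in order to a `metadata` list, and an optional header row (such as id \t text \t title) can be skipped. `read_tsv_documents` is hypothetical.

```python
from typing import Dict, Iterable, Iterator, List, Union

Document = Dict[str, Union[str, List[str]]]


def read_tsv_documents(lines: Iterable[str], skip_header: bool = False) -> Iterator[Document]:
    """Yield {"id", "text", "metadata"} dicts from tab-separated lines.

    Columns after id and text are collected, in order, into the metadata list
    (a hypothetical representation; the library may store them differently).
    """
    for line_no, line in enumerate(lines, start=1):
        if skip_header and line_no == 1:
            continue
        line = line.rstrip("\r\n")
        if not line:
            continue  # ignore empty lines
        columns = line.split("\t")
        if len(columns) < 2:
            raise ValueError(f"line {line_no}: expected at least id and text columns")
        doc_id, text, *extra = columns
        yield {"id": doc_id, "text": text, "metadata": extra}
```

Splitting on tabs directly keeps the sketch simple, but it does not unquote fields; a file whose text column contains quoted tabs would need Python's `csv` module with `delimiter="\t"` instead.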