Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Domain Expert Evaluator #5

Merged
merged 16 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,10 @@ python -m build

### ✅ TODO
- [ ] Add option to few-shot examples
- [x] Publish on PyPi
- [ ] Add custom types
- [ ] Testing!
- [ ] Add CI/CD for publishing
- [x] Publish on PyPi
- [x] Add more document evaluators (Microsoft)
- [x] Split Elo evaluator
- [x] Install as standalone CLI
Expand Down
21 changes: 18 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,20 @@ classifiers = [
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: System :: Benchmark",
]
dependencies = ["openai", "tenacity", "typer"]
dependencies = ["openai", "tenacity", "typer", "numpy"]

[project.optional-dependencies]
cli = ["typer[all]"]
dev = ["bandit==1.7.5", "black==23.10.0", "isort==5.12.0", "flake8==6.1.0", "flake8-black==0.3.6", "flake8-isort==6.1.0", "mypy==1.6.1"]
dev = [
"bandit==1.7.5",
"black==23.10.0",
"isort==5.12.0",
"flake8==6.1.0",
"flake8-black==0.3.6",
"flake8-isort==6.1.0",
"mypy==1.6.1",
"Flake8-pyproject==1.2.3",
]

[project.scripts]
ragelo = "ragelo.cli:app"
Expand All @@ -45,9 +54,15 @@ profile = "black"

[tool.mypy]
python_version = "3.11"
ignore_missing_imports = true
show_column_numbers = true
namespace_packages = true
exclude = ["build/", "dist/", "venv/"]

[tool.flake8]
ignore = ['E501', "W503"]
per-file-ignores = ['__init__.py:F401,F403']
exclude = ["build/", "dist/", "venv/"]


[tool.setuptools-git-versioning]
enabled = true
2 changes: 1 addition & 1 deletion ragelo/answer_evaluators/base_answer_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Dict, Set, Type

from ragelo.logger import logger
from ragelo.opeanai_client import OpenAiClient, set_credentials_from_file
from ragelo.utils.openai_client import OpenAiClient, set_credentials_from_file


class AnswerEvaluator:
Expand Down
1 change: 1 addition & 0 deletions ragelo/doc_evaluators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from ragelo.doc_evaluators.domain_expert import *
from ragelo.doc_evaluators.rdnam_evaluator import *
from ragelo.doc_evaluators.reasoner_evaluator import *
173 changes: 114 additions & 59 deletions ragelo/doc_evaluators/base_doc_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import os
from abc import abstractmethod
from collections import defaultdict
from typing import Any, Callable, Dict, Type
from contextlib import nullcontext
from functools import partial
from typing import Any, Callable, Dict, List, Set, Tuple, Type

from tenacity import RetryError

from ragelo.logger import logger
from ragelo.opeanai_client import OpenAiClient, set_credentials_from_file
from ragelo.utils.openai_client import OpenAiClient, set_credentials_from_file


class DocumentEvaluator:
Expand All @@ -29,69 +31,80 @@ def __init__(
self.output_file = output_file
self.queries = self._load_queries(query_path)
self.documents = self._load_documents(documents_path)
if verbose:
logger.setLevel("INFO")

if credentials_file:
set_credentials_from_file(credentials_file)

self.openai_client = OpenAiClient(model=model_name)
self.progress_bar: Callable = nullcontext
try:
from rich.progress import Progress

def get_answers(self):
self.progress_bar = partial(Progress, transient=True)
self.rich = True
except ImportError:
self.rich = False

def get_answers(self) -> Dict[str, Dict[str, Any]]:
    """Evaluate every query-document pair and return the collected answers.

    Pairs returned by ``self._get_skip_docs()`` are not evaluated again.
    Each remaining pair is sent through ``self._process_single_answer``;
    the answer is then logged with ``self._print_response`` and written
    out with ``self._dump_response``. Pairs whose evaluation raises
    ``RetryError`` or ``ValueError`` are left out of the result.

    Returns:
        A mapping ``{query_id: {document_id: answer}}`` holding only the
        pairs evaluated during this call.
    """
    use_bar = self.verbose and self.rich
    skip_docs = self._get_skip_docs()
    answers: Dict[str, Dict[str, Any]] = defaultdict(dict)
    with self.progress_bar() as progress:
        # ``progress`` is None when ``self.progress_bar`` is ``nullcontext``.
        show_progress = bool(use_bar) and progress is not None
        # One overall task for the queries, plus one task per query for its
        # documents. Task ids are compared against None below because the
        # first id handed out by rich's ``add_task`` is 0, which is falsy.
        q_progress = (
            progress.add_task(
                "[bold blue]Annotating Documents", total=len(self.queries)
            )
            if show_progress
            else None
        )
        for qid in self.queries:
            d_progress = (
                progress.add_task(
                    f"[bold white]{qid}", total=len(self.documents[qid])
                )
                if show_progress
                else None
            )
            for did in self.documents[qid]:
                if (qid, did) in skip_docs:
                    logger.debug(f"Skipping {qid} {did}")
                    continue

                try:
                    answer = self._process_single_answer(qid, did)
                except (RetryError, ValueError):
                    # The failure has already been logged by
                    # ``_process_single_answer``; move on to the next pair.
                    continue
                self._print_response(qid, did, answer)
                self._dump_response(qid, did, answer)
                answers[qid][did] = answer
                if d_progress is not None:
                    progress.update(d_progress, advance=1, refresh=True)
            if q_progress is not None:
                if d_progress is not None:
                    progress.stop_task(d_progress)
                progress.update(q_progress, advance=1, refresh=True)
    return answers

def _process_single_answer(self, qid: str, did: str) -> str:
    """Submit one query-document pair to the LLM and return the parsed answer.

    The prompt comes from ``self._build_message``; the raw reply from
    ``self.openai_client`` is passed through ``self._process_answer``.
    ``RetryError`` and ``ValueError`` are logged as warnings and re-raised
    to the caller. Override this method to implement a custom evaluator
    (e.g., two-shot).
    """
    prompt = self._build_message(qid, did)
    try:
        parsed = self._process_answer(self.openai_client(prompt))
    except RetryError as fetch_error:
        logger.warning(f"Failed to FETCH answers for {qid} {did}")
        raise fetch_error
    except ValueError as parse_error:
        logger.warning(f"Failed to PARSE answer for {qid} {did}")
        raise parse_error
    return parsed

@abstractmethod
def _build_message(self, qid: str, did: str) -> str:
Expand Down Expand Up @@ -153,7 +166,49 @@ def _load_documents(self, documents_path: str) -> Dict[str, Dict[str, str]]:
logger.info(f"Loaded {len(rows)} documents")
return rows

def __load_from_csv(self, file_path: str) -> Dict[str, str]:
def _get_skip_docs(self) -> Set[Tuple[str, str]]:
    """Return the (query_id, document_id) pairs already in the output file.

    When ``self.force`` is false and ``self.output_file`` exists, the file
    is read as CSV and the first two columns of every data row are
    collected. When ``self.force`` is true and the file exists, the file is
    removed and an empty set is returned.

    Returns:
        The set of pairs that should not be evaluated again.
    """
    skip_docs: Set[Tuple[str, str]] = set()
    output_exists = os.path.isfile(self.output_file)
    if output_exists and not self.force:
        # newline="" lets the csv module handle quoted fields that contain
        # line breaks; the ``with`` block closes the handle when done.
        with open(self.output_file, newline="") as f:
            for row_number, row in enumerate(csv.reader(f)):
                if len(row) < 2:
                    # Blank or truncated row: nothing to identify a pair.
                    continue
                if row_number == 0 and row[:2] == ["query_id", "did"]:
                    # Header row, not an annotated document.
                    continue
                skip_docs.add((row[0], row[1]))
    if self.force and output_exists:
        logger.warning(f"Removing existing {self.output_file}!")
        os.remove(self.output_file)
    if skip_docs:
        logger.warning(
            f"Skipping {len(skip_docs)} documents already annotated! "
            "If you want to reannotate them, please use the --force flag"
        )
    return skip_docs

def _print_response(self, qid: str, did: str, answer: str) -> None:
    """Log the query text, the document id and the evaluation at INFO level.

    The messages carry bracketed style tags (e.g. ``[bold cyan]``) and are
    followed by an empty log line that separates consecutive responses.
    """
    query_text = self.queries[qid]
    logger.info(
        "[bold cyan]Query [/bold cyan]: "
        + f"[not bold cyan]{query_text}[/not bold cyan]"
    )

    document_line = f"[bold cyan]Document ID [/bold cyan]: {did}"
    logger.info(document_line)

    evaluation_line = (
        "[bold cyan]Evaluation [/bold cyan]: " + f"[not bold]{answer}[/not bold]"
    )
    logger.info(evaluation_line)

    logger.info("")

def _dump_response(
    self, qid: str, did: str, answer: str | List[str], file: str | None = None
) -> None:
    """Append one evaluation row to a CSV file.

    Args:
        qid: Query identifier, written to the first column.
        did: Document identifier, written to the second column.
        answer: The evaluation; a list is joined with newlines into a
            single cell.
        file: Destination path. Falls back to ``self.output_file`` when
            empty or ``None``.

    A header row (``query_id, did, answer``) is written first when the
    destination file does not exist yet.
    """
    output_file = file if file else self.output_file
    needs_header = not os.path.isfile(output_file)
    if isinstance(answer, list):
        answer = "\n".join(answer)

    # newline="" is what the csv module expects: it stops the platform from
    # translating line endings, so embedded newlines stay inside their
    # quoted cell and no blank rows appear on Windows.
    with open(output_file, "a", newline="") as f:
        writer = csv.writer(f)
        if needs_header:
            writer.writerow(["query_id", "did", "answer"])
        writer.writerow([qid, did, answer])

def _load_from_csv(self, file_path: str) -> Dict[str, str]:
"""extra content from a CSV file"""
contents = {}
for line in csv.reader(open(file_path, "r")):
Expand Down Expand Up @@ -182,7 +237,7 @@ def inner_wrapper(
return inner_wrapper

@classmethod
def create(cls, evaluator_name: str, *args, **kwargs) -> "DocumentEvaluator":
    """Instantiate the evaluator registered under ``evaluator_name``.

    Args:
        evaluator_name: Key looked up in ``cls.registry``; it is also
            passed to the evaluator as ``prompt_name``.
        *args: Positional arguments forwarded to the evaluator.
        **kwargs: Keyword arguments forwarded to the evaluator.

    Returns:
        The object built by the registered evaluator class.

    Raises:
        ValueError: If ``evaluator_name`` is not in ``cls.registry``.
    """
    if evaluator_name not in cls.registry:
        raise ValueError(f"Unknown evaluator {evaluator_name}")
    evaluator_cls = cls.registry[evaluator_name]
    # Positional unpacking goes first: Python binds ``*args`` before keyword
    # arguments regardless of where they are written, so this order reads
    # the way the call is actually evaluated.
    return evaluator_cls(*args, prompt_name=evaluator_name, **kwargs)
Loading