diff --git a/src/helm/benchmark/annotation/model_as_judge.py b/src/helm/benchmark/annotation/model_as_judge.py
index 7551bf2680..3a92602908 100644
--- a/src/helm/benchmark/annotation/model_as_judge.py
+++ b/src/helm/benchmark/annotation/model_as_judge.py
@@ -13,7 +13,7 @@ def __init__(self, response_text: str, **kwargs):
 
 @dataclass
-class _AnnotatorModelInfo:
+class AnnotatorModelInfo:
     model_name: str
     model_deployment: str
 
 
@@ -26,9 +26,9 @@ def score_with_reasoning_with_gpt_and_llama(
 
     Score using GPT-4o and Llama 3.1 for safety scenarios in HELM Safety."""
     # TODO: Make this configurable
-    SHORT_NAME_TO_MODEL_INFO: Dict[str, _AnnotatorModelInfo] = {
-        "gpt": _AnnotatorModelInfo(model_name="openai/gpt-4o-2024-05-13", model_deployment="openai/gpt-4o-2024-05-13"),
-        "llama": _AnnotatorModelInfo(
+    SHORT_NAME_TO_MODEL_INFO: Dict[str, AnnotatorModelInfo] = {
+        "gpt": AnnotatorModelInfo(model_name="openai/gpt-4o-2024-05-13", model_deployment="openai/gpt-4o-2024-05-13"),
+        "llama": AnnotatorModelInfo(
             model_name="meta/llama-3.1-405b-instruct-turbo", model_deployment="together/llama-3.1-405b-instruct-turbo"
         ),
     }
diff --git a/src/helm/benchmark/annotation/omni_math_annotator.py b/src/helm/benchmark/annotation/omni_math_annotator.py
index 4056fe8e7b..724d918e11 100644
--- a/src/helm/benchmark/annotation/omni_math_annotator.py
+++ b/src/helm/benchmark/annotation/omni_math_annotator.py
@@ -1,8 +1,9 @@
-from typing import Any
+from typing import Any, Dict
 from importlib.resources import files
 
 from helm.benchmark.adaptation.request_state import RequestState
 from helm.benchmark.annotation.annotator import Annotator
+from helm.benchmark.annotation.model_as_judge import AnnotatorModelInfo
 from helm.clients.auto_client import AutoClient
 from helm.common.request import Request
 
@@ -51,28 +52,52 @@ def annotate(self, request_state: RequestState) -> Any:
                 "justification": "The model output is empty.",
             }
 
-        annotator_request = Request(
-            model="openai/gpt-4o-2024-05-13",
-            model_deployment="openai/gpt-4o-2024-05-13",
-            prompt=annotator_prompt,
-            temperature=0.0,
-            max_tokens=1000,
-        )
-        annotator_response = self._auto_client.make_request(annotator_request)
-        if not annotator_response.success:
-            raise Exception(f"Annotation request failed: {annotator_response.error}")
-        assert len(annotator_response.completions) == 1
-        annotator_response_text = annotator_response.completions[0].text
+        SHORT_NAME_TO_MODEL_INFO: Dict[str, AnnotatorModelInfo] = {
+            "gpt": AnnotatorModelInfo(
+                model_name="openai/gpt-4o-2024-05-13", model_deployment="openai/gpt-4o-2024-05-13"
+            ),
+            "llama": AnnotatorModelInfo(
+                model_name="meta/llama-3.1-405b-instruct-turbo",
+                model_deployment="together/llama-3.1-405b-instruct-turbo",
+            ),
+            "claude": AnnotatorModelInfo(
+                model_name="anthropic/claude-3-5-sonnet-20241022",
+                model_deployment="anthropic/claude-3-5-sonnet-20241022",
+            ),
+        }
+        all_student_final_answers = []
+        all_equivalence_judgements = []
+        all_justifications = []
+        for annotator_model in SHORT_NAME_TO_MODEL_INFO:
+            annotator_model_info = SHORT_NAME_TO_MODEL_INFO[annotator_model]
+            annotator_request = Request(
+                model=annotator_model_info.model_name,
+                model_deployment=annotator_model_info.model_deployment,
+                prompt=annotator_prompt,
+                temperature=0.0,
+                max_tokens=1000,
+            )
+            annotator_response = self._auto_client.make_request(annotator_request)
+            if not annotator_response.success:
+                raise Exception(f"Annotation request failed: {annotator_response.error}")
+            assert len(annotator_response.completions) == 1
+            annotator_response_text = annotator_response.completions[0].text
+
+            info = parse_report(annotator_response_text)
 
-        info = parse_report(annotator_response_text)
+            equivalence_judgement = info.get("Equivalence Judgement", "")
+            student_final_answer = info.get("Student Final Answer", "")
+            justification = info.get("Justification", "").strip().removesuffix("=== report over ===").strip()
+            if equivalence_judgement == "":
+                continue  # skip this annotator if there is no equivalence judgement parsed
 
-        equivalence_judgement = info.get("Equivalence Judgement", "")
-        student_final_answer = info.get("Student Final Answer", "")
-        justification = info.get("Justification", "").strip().removesuffix("=== report over ===").strip()
+            all_student_final_answers.append(student_final_answer)
+            all_equivalence_judgements.append(equivalence_judgement)
+            all_justifications.append(justification)
 
         return {
             "prompt_text": annotator_prompt,
-            "student_final_answer": student_final_answer,
-            "equivalence_judgement": equivalence_judgement,
-            "justification": justification,
+            "student_final_answer": all_student_final_answers,
+            "equivalence_judgement": all_equivalence_judgements,
+            "justification": all_justifications,
         }
diff --git a/src/helm/benchmark/annotation/wildbench_annotator.py b/src/helm/benchmark/annotation/wildbench_annotator.py
index 1141e79f46..1458baf0ab 100644
--- a/src/helm/benchmark/annotation/wildbench_annotator.py
+++ b/src/helm/benchmark/annotation/wildbench_annotator.py
@@ -5,7 +5,7 @@
 
 from helm.benchmark.adaptation.request_state import RequestState
 from helm.benchmark.annotation.annotator import Annotator
-from helm.benchmark.annotation.model_as_judge import _AnnotatorModelInfo
+from helm.benchmark.annotation.model_as_judge import AnnotatorModelInfo
 from helm.clients.auto_client import AutoClient
 from helm.common.request import Request
 
@@ -46,15 +46,15 @@ def annotate(self, request_state: RequestState) -> Any:
             .replace("{$checklist}", "\n".join(request_state.instance.extra_data["checklist"]))
         )
 
-        SHORT_NAME_TO_MODEL_INFO: Dict[str, _AnnotatorModelInfo] = {
-            "gpt": _AnnotatorModelInfo(
+        SHORT_NAME_TO_MODEL_INFO: Dict[str, AnnotatorModelInfo] = {
+            "gpt": AnnotatorModelInfo(
                 model_name="openai/gpt-4o-2024-05-13", model_deployment="openai/gpt-4o-2024-05-13"
             ),
-            "llama": _AnnotatorModelInfo(
+            "llama": AnnotatorModelInfo(
                 model_name="meta/llama-3.1-405b-instruct-turbo",
                 model_deployment="together/llama-3.1-405b-instruct-turbo",
             ),
-            "claude": _AnnotatorModelInfo(
+            "claude": AnnotatorModelInfo(
                 model_name="anthropic/claude-3-5-sonnet-20241022",
                 model_deployment="anthropic/claude-3-5-sonnet-20241022",
             ),
diff --git a/src/helm/benchmark/metrics/omni_math_metrics.py b/src/helm/benchmark/metrics/omni_math_metrics.py
index 48ff67ac33..e85ee0b0f9 100644
--- a/src/helm/benchmark/metrics/omni_math_metrics.py
+++ b/src/helm/benchmark/metrics/omni_math_metrics.py
@@ -19,7 +19,11 @@ def evaluate_generation(
         eval_cache_path: str,
     ) -> List[Stat]:
         assert request_state.annotations
-        score = request_state.annotations["omni_math"]["equivalence_judgement"].strip().upper() == "TRUE"
+        all_judgements = request_state.annotations["omni_math"]["equivalence_judgement"]
+        if len(all_judgements) == 0:
+            raise ValueError("Could not compute Omni-MATH accuracy because all annotators failed.")
+        judgement_bools = [judgement.strip().upper() == "TRUE" for judgement in all_judgements]
+        score = sum(judgement_bools) / len(judgement_bools)
         return [
            Stat(MetricName("omni_math_accuracy")).add(score),
         ]
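
Illustrative sketch, not part of the patch: with the ensemble annotator above, the "equivalence_judgement" annotation becomes a list of per-judge verdicts, and the metric reports the fraction judged TRUE. The annotation values below are hypothetical.

# Hypothetical annotation from the three-judge ensemble (values are made up).
annotations = {
    "omni_math": {
        "equivalence_judgement": ["TRUE", "TRUE", "FALSE"],  # e.g. gpt, llama, claude
    }
}

# Mirrors the aggregation in omni_math_metrics.py: fraction of judges saying TRUE.
all_judgements = annotations["omni_math"]["equivalence_judgement"]
if len(all_judgements) == 0:
    raise ValueError("Could not compute Omni-MATH accuracy because all annotators failed.")
judgement_bools = [judgement.strip().upper() == "TRUE" for judgement in all_judgements]
score = sum(judgement_bools) / len(judgement_bools)
print(score)  # 2/3 for this example; a judge that failed to parse is simply absent from the list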