diff --git a/run.py b/run.py
index 71377999e..c413fadb6 100644
--- a/run.py
+++ b/run.py
@@ -1,7 +1,7 @@
 import torch
 import torch.distributed as dist
 from vlmeval.smp import *
-from vlmeval.evaluate import COCO_eval, MME_eval, MMVet_eval, multiple_choice_eval, MME_rating, VQAEval
+from vlmeval.evaluate import COCO_eval, MME_eval, MMVet_eval, multiple_choice_eval, MME_rating, VQAEval, MathVista_eval
 from vlmeval.inference import infer_data_job, prefetch_acc
 from vlmeval.config import supported_VLM
 from vlmeval.utils import dataset_URLs, abbr2full
@@ -82,6 +82,8 @@ def main():
                 COCO_eval(result_file)
             elif listinstr(['OCRVQA', 'TextVQA', 'ChartQA'], dataset_name):
                 VQAEval(result_file)
+            elif listinstr(['MathVista'], dataset_name):
+                MathVista_eval(result_file, model='gpt-4-turbo', nproc=args.nproc, verbose=args.verbose)
             else:
                 logger.error(f'Dataset {dataset_name} is not handled by evaluator, will be skipped. ')
diff --git a/vlmeval/evaluate/__init__.py b/vlmeval/evaluate/__init__.py
index 03a3bfb98..fc4d22688 100644
--- a/vlmeval/evaluate/__init__.py
+++ b/vlmeval/evaluate/__init__.py
@@ -2,4 +2,5 @@
 from .mmvet_eval import MMVet_eval
 from .multiple_choice import multiple_choice_eval
 from .coco_eval import COCO_eval
-from .vqa_eval import VQAEval
\ No newline at end of file
+from .vqa_eval import VQAEval
+from .mathvista_eval import MathVista_eval
\ No newline at end of file
diff --git a/vlmeval/evaluate/mathvista_eval.py b/vlmeval/evaluate/mathvista_eval.py
new file mode 100644
index 000000000..385b70430
--- /dev/null
+++ b/vlmeval/evaluate/mathvista_eval.py
@@ -0,0 +1,233 @@
+from vlmeval.api import OpenAIWrapper, OpenAIWrapperInternal
+from vlmeval.smp import *
+from vlmeval.utils import track_progress_rich
+from vlmeval.utils.matching_util import can_infer
+
+INTERNAL = os.environ.get('INTERNAL', 0)
+
+def get_gpt4_ICE():
+    example_1 = """
+    Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end.\n
+    Question: Which number is missing?\n
+    Model response: The number missing in the sequence is 14.\n
+    Extracted answer: 14
+    """
+
+    example_2 = """
+    Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value, e.g., 1.2, 1.3, 1.4, at the end.\n
+    Question: What is the fraction of females facing the camera?\n
+    Model response: The fraction of females facing the camera is 0.6, which means that six out of ten females in the group are facing the camera.\n
+    Extracted answer: 0.6
+    """
+
+    example_3 = """
+    Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value, e.g., 1.23, 1.34, 1.45, at the end.\n
+    Question: How much money does Luca need to buy a sour apple candy and a butterscotch candy? (Unit: $)\n
+    Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.\n
+    Extracted answer: 1.45
+    """
+
+    example_4 = """
+    Hint: Please answer the question requiring a Python list as an answer and provide the final list, e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end.\n
+    Question: Between which two years does the line graph saw its maximum peak?\n
+    Model response: The line graph saw its maximum peak between 2007 and 2008.\n
+    Extracted answer: [2007, 2008]
+    """
+
+    example_5 = """
+    Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.\n
+    Question: What fraction of the shape is blue?\n
+    Choices: (A) 3/11 (B) 8/11 (C) 6/11 (D) 3/5\n
+    Model response: The correct answer is (B) 8/11.\n
+    Extracted answer: B
+    """
+    return [example_1, example_2, example_3, example_4, example_5]
+
+
+def build_mathvista_gpt4_prompt(line):
+    task_description = """ Please read the following example. Then extract the answer from the model response and type it at the end of the prompt.\n"""
+    question = line['question']
+    prediction = str(line['prediction'])
+    prompt = task_description
+    examples = get_gpt4_ICE()
+    for example in examples:
+        prompt += example + '\n'
+    prompt += question + '\n'
+    prompt += 'Model response: ' + prediction
+    prompt += '\nExtracted answer:'
+    return prompt
+
+def list_to_dict(lst):
+    return {chr(65 + i): val for i, val in enumerate(lst)}
+
+def post_check(line, prefetch=False):
+    res = None
+    ans = line['answer']
+    response = line['prediction'] if prefetch else line['res']
+    try:
+        if line['question_type'] == 'multi_choice':
+            ans = line['answer_option']
+            choices = list_to_dict(eval(line['choices']))
+            res = can_infer(response, choices)
+            if prefetch:
+                return res
+        else:
+            if line['answer_type'] == 'integer':
+                res = int(response)
+                ans = int(line['answer'])
+            elif line['answer_type'] == 'float':
+                res = float(response)
+                ans = float(line['answer'])
+            else:
+                res = str(response)
+                ans = str(ans)
+    except ValueError:
+        pass
+
+    if res == ans:
+        return res
+    else:
+        return False
+
+def MathVista_auxeval(model, line):
+    prompt = build_mathvista_gpt4_prompt(line)
+    log = ''
+    retry = 5
+    if post_check(line, prefetch=True):
+        res = post_check(line, prefetch=True)
+        return dict(log='Prefetch succeed', res=res)
+    for i in range(retry):
+        prediction = line['prediction']
+        res = model.generate(prompt, temperature=i * 0.5)
+        if res is None:
+            log += f'Try {i}: output is {prediction}, failed to parse.\n'
+        else:
+            log += 'Succeed'
+            return dict(log=log, res=res)
+    log += 'All 5 retries failed.\n'
+    return dict(log=log, res='')
+
+def MathVista_acc(result_file):
+    data = load(result_file)
+    tot = defaultdict(lambda: 0)
+    fetch = defaultdict(lambda: 0)
+    hit = defaultdict(lambda: 0)
+    lt = len(data)
+    skill_list = []
+    for i in range(lt):
+        item = data.iloc[i]
+        index = item['index']
+        cate = item['task']
+        tot['Overall'] += 1
+        try:
+            skills = eval(item['skills'])
+        except SyntaxError:
+            skills = [item['skills']]
+        for skill in skills:
+            if skill not in skill_list:
+                skill_list.append(skill)
+            tot[skill] += 1
+        tot[cate] += 1
+        if item['log'] == 'Prefetch succeed':
+            fetch['Overall'] += 1
+            fetch[cate] += 1
+            for skill in skills:
+                fetch[skill] += 1
+        if post_check(item, prefetch=False):
+            hit['Overall'] += 1
+            hit[cate] += 1
+            for skill in skills:
+                hit[skill] += 1
+
+    res = defaultdict(list)
+    for k in tot.keys():
+        res['Task&Skill'].append(k)
+        res['tot'].append(tot[k])
+        res['prefetch'].append(fetch[k])
+        res['hit'].append(hit[k])
+        res['prefetch_rate'].append(fetch[k] / tot[k] * 100)
+        res['acc'].append(hit[k] / tot[k] * 100)
+    res = pd.DataFrame(res)
+    return res
+
+def MathVista_eval(eval_file, model='gpt-4-turbo', nproc=4, verbose=False):
+    logger = get_logger('Evaluation')
+
+    suffix = eval_file.split('.')[-1]
+    storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
+    tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
+    if osp.exists(storage):
+        logger.warning(f"GPT scoring file {storage} already exists, will reuse it in MathVista_eval. ")
+    else:
+        data = load(eval_file)
+        gpt_version = model
+
+        model_map = {
+            'gpt-4-turbo': 'gpt-4-1106-preview',
+            'gpt-4-0613': 'gpt-4-0613',
+            'chatgpt-1106': 'gpt-3.5-turbo-1106',
+            'chatgpt-0613': 'gpt-3.5-turbo-0613'
+        }
+        model_version = model_map[gpt_version]
+
+        if INTERNAL:
+            # Use the GPT judge to extract the final answer from the raw model response
+            model = OpenAIWrapperInternal(model_version, verbose=verbose, max_tokens=128, retry=10)
+        else:
+            model = OpenAIWrapper(model_version, verbose=verbose, max_tokens=128, retry=10)
+
+        lt = len(data)
+        lines = [data.iloc[i] for i in range(lt)]
+        tups = [(model, line) for line in lines]
+        indices = [line['index'] for line in lines]
+
+        ans = {}
+        if osp.exists(tmp_file):
+            ans = load(tmp_file)
+        tups = [x for x, i in zip(tups, indices) if i not in ans]
+        indices = [i for i in indices if i not in ans]
+
+        if len(indices):
+            new_results = track_progress_rich(
+                MathVista_auxeval, tups, nproc=nproc, chunksize=nproc,
+                keys=indices, save=tmp_file)
+            ans = load(tmp_file)
+            for k, v in zip(indices, new_results):
+                assert k in ans
+                assert ans[k]['log'] == v['log'] and ans[k]['res'] == v['res']
+
+        log_map, res_map = {}, {}
+        all_inds = [line['index'] for line in lines]
+        for k in all_inds:
+            log_map[k] = ans[k]['log']
+            res_map[k] = ans[k]['res']
+        data['res'] = [res_map[idx] for idx in data['index']]
+        data['log'] = [log_map[idx] for idx in data['index']]
+        dump(data, storage)
+
+    score = MathVista_acc(storage)
+    score_pth = storage.replace('.xlsx', '_score.csv')
+
+    dump(score, score_pth)
+    logger.info(f'MathVista_eval successfully finished evaluating {eval_file}, results saved in {score_pth}')
+    logger.info('Score: ')
+    logger.info(score)
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="MathVista Evaluation. ")
+    parser.add_argument("data", type=str, help="The prediction file to evaluate, in excel / tsv / json format. ")
+    parser.add_argument(
+        "--model",
+        type=str,
+        help="The LLM (GPT) used for answer extraction. ",
+        default="gpt-4-turbo",
+        choices=['gpt-4-0613', 'gpt-4-turbo', 'chatgpt-1106', 'chatgpt-0613'])
+    parser.add_argument("--nproc", type=int, default=4)
+    parser.add_argument("--verbose", action='store_true')
+    args = parser.parse_args()
+    return args
+
+if __name__ == '__main__':
+    args = parse_args()
+    MathVista_eval(eval_file=args.data, model=args.model, nproc=args.nproc, verbose=args.verbose)
+
diff --git a/vlmeval/inference.py b/vlmeval/inference.py
index ae568f193..7d67a0217 100644
--- a/vlmeval/inference.py
+++ b/vlmeval/inference.py
@@ -249,7 +249,7 @@ def main():
             model = model_name  # which is only a name
         model = infer_data_job(model, model_name=model_name, dataset_name=dataset_name, verbose=args.verbose, api_nproc=args.nproc)
 
-        if rank == 0 and not listinstr(['MME', 'CORE_MM', 'MMVet', 'COCO', 'MMMU'], dataset_name):
+        if rank == 0 and not listinstr(['MME', 'CORE_MM', 'MMVet', 'COCO', 'MMMU', 'MathVista'], dataset_name):
             time.sleep(3)
             res = prefetch_acc(result_file)
             print(model_name, res)
diff --git a/vlmeval/utils/dataset_config.py b/vlmeval/utils/dataset_config.py
index 8cd9c4c10..58a6cdb75 100644
--- a/vlmeval/utils/dataset_config.py
+++ b/vlmeval/utils/dataset_config.py
@@ -17,6 +17,7 @@
     "OCRVQA_TESTCORE": "https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TESTCORE.tsv",
     'TextVQA_VAL': "https://opencompass.openxlab.space/utils/VLMEval/TextVQA_VAL.tsv",
     "MMMU_DEV_VAL": "https://opencompass.openxlab.space/utils/VLMEval/MMMU_DEV_VAL.tsv",
+    "MathVista_MINI": "https://opencompass.openxlab.space/utils/VLMEval/MathVista_MINI.tsv",
     'ChartQA_VALTEST_HUMAN': "https://opencompass.openxlab.space/utils/VLMEval/ChartQA_VALTEST_HUMAN.tsv",
     'ScienceQA_VAL': "https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_VAL.tsv",
     'ScienceQA_TEST': "https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_TEST.tsv"
@@ -39,6 +40,7 @@
     'OCRVQA_TESTCORE': 'c5239fe77db8bdc1f2ad8e55e0d1fe97',
     'TextVQA_VAL': 'b233b31f551bbf4056f2f955da3a92cd',
     'MMMU_DEV_VAL': "501f84dc642a9b17e35363b78c0191e1",
+    'MathVista_MINI': 'f199b98e178e5a2a20e7048f5dcb0464',
     'ChartQA_VALTEST_HUMAN':'2c90a4133408a21d57fb2ea26f77bbfc',
     'ScienceQA_VAL': '96320d05e142e585e7204e72affd29f3',
     'ScienceQA_TEST': 'e42e9e00f9c59a80d8a5db35bc32b71f'
@@ -61,6 +63,7 @@
     'OCRVQA_TESTCORE': 'OCRVQA',
     'TextVQA_VAL': 'TextVQA',
     'MMMU_DEV_VAL': 'MMMU',
+    'MathVista_MINI': 'MathVista',
     'ChartQA_VALTEST_HUMAN': 'ChartQA',
     'ScienceQA_VAL': 'ScienceQA',
     'ScienceQA_TEST': 'ScienceQA'
@@ -75,7 +78,7 @@ def DATASET_TYPE(dataset):
         return 'Y/N'
     elif 'COCO' in dataset:
         return 'Caption'
-    elif listinstr(['ocrvqa', 'textvqa', 'chartqa'], dataset.lower()):
+    elif listinstr(['ocrvqa', 'textvqa', 'chartqa', 'mathvista'], dataset.lower()):
         return 'VQA'
     return None