-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 5aac08b
Showing
9 changed files
with
5,119 additions
and
0 deletions.
There are no files selected for viewing
110 changes: 110 additions & 0 deletions
110
GPT_eval/detailed_evaluation/difficulty_specific_evaluation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import json | ||
import pdb | ||
import argparse | ||
|
||
def main(args):
    """Print model accuracy per difficulty bucket for a GPT-eval jsonl file.

    Each input line is a JSON object with:
      * 'original_json': JSON string of the benchmark entry
        (source, domain, difficulty, problem, answer, model_generation)
      * 'gen': the GPT evaluation report in '## Section' markdown form

    Args:
        args: namespace with an ``input_file`` attribute (path to the jsonl).
    """
    data = []
    with open(args.input_file) as f:
        for line in f:
            data.append(json.loads(line))

    def parse_report(report):
        """Split a '## '-sectioned report into a {title: content} dict.

        The 'Justification' section may span multiple lines, so it keeps its
        full body; every other section keeps only its first content line
        (or '' when the section is empty).
        """
        parsed = {}
        for part in report.split("## ")[1:]:
            lines = part.strip().split("\n")
            title = lines[0].strip()
            if title == "Justification":
                parsed[title] = "\n".join(lines[1:]).strip()
            else:
                parsed[title] = lines[1].strip() if len(lines) > 1 else ''
        return parsed

    # Flatten each raw line into a record with a boolean 'correctness'.
    target_entry = []
    for entry in data:
        original_json = json.loads(entry['original_json'])
        info = parse_report(entry['gen'])
        if not info:
            continue
        try:
            verdict = info['Equivalence Judgement']
            record = {key: original_json[key]
                      for key in ('source', 'domain', 'difficulty',
                                  'problem', 'answer', 'model_generation')}
        except KeyError:
            # Malformed report or incomplete benchmark entry: skip it
            # (the original silently skipped these with a bare except).
            continue
        record['correctness'] = (verdict == 'TRUE')
        target_entry.append(record)

    # Bucket boundaries reproduce the original split exactly; note the names
    # 'difficulty_5_7' / 'difficulty_7_10' historically cover (5, 8] and
    # (8, 10] respectively.
    difficulty_1_3 = []
    difficulty_3_5 = []
    difficulty_5_7 = []
    difficulty_7_10 = []
    for rec in target_entry:
        d = rec['difficulty']
        if 1 <= d <= 3:
            difficulty_1_3.append(rec)
        elif 3 < d <= 5:
            difficulty_3_5.append(rec)
        elif 5 < d <= 8:
            difficulty_5_7.append(rec)
        elif 8 < d <= 10:
            difficulty_7_10.append(rec)

    def report_bucket(name, records, skip_proofs=False):
        """Print '<name>: <accuracy>' for a bucket; silent when it is empty.

        When skip_proofs is True, proof-style problems are excluded from
        both numerator and denominator.
        """
        acc = 0
        cnt = 0
        for rec in records:
            if skip_proofs:
                text = rec['problem'].lower()
                if 'prove' in text or 'proof' in text:
                    continue
            if rec['correctness']:
                acc += 1
            cnt += 1
        # Guard every bucket: the original crashed with ZeroDivisionError
        # when difficulty_5_7 was empty because only it lacked this check.
        if cnt != 0:
            print('{}: {}'.format(name, acc / cnt))

    report_bucket('difficulty_1_3', difficulty_1_3)
    report_bucket('difficulty_3_5', difficulty_3_5)
    report_bucket('difficulty_5_7', difficulty_5_7)
    report_bucket('difficulty_7_10', difficulty_7_10, skip_proofs=True)
def parse_args():
    """Return the parsed CLI namespace; --input_file is the gpt-eval jsonl path."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--input_file", type=str)  # path to the input jsonl
    return arg_parser.parse_args()
|
||
if __name__ == "__main__":
    # Script entry point: parse CLI options, then run the evaluation.
    main(parse_args())
144 changes: 144 additions & 0 deletions
144
GPT_eval/detailed_evaluation/domain_specific_evaluation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
import json | ||
import pdb | ||
import argparse | ||
from prettytable import PrettyTable | ||
from collections import defaultdict | ||
|
||
def update_domain(target_entry,
                  reference_file='/mnt/workspace/hard_math/OmniPic_Bench/OmniPic_Bench_v1_final.jsonl'):
    """Overwrite each entry's 'domain' with the domain from the reference benchmark.

    Entries are matched by exact 'problem' text; entries whose problem is not
    found in the reference file keep their current domain.

    Args:
        target_entry: list of dicts, each with at least 'problem' and 'domain'.
        reference_file: jsonl benchmark providing the canonical domains.
            Defaults to the original hard-coded path so existing callers
            are unaffected; pass a path to use another reference.

    Returns:
        The same list, mutated in place.
    """
    # Build a problem -> domain lookup from the reference benchmark.
    problem_to_domain = {}
    with open(reference_file) as f:
        for line in f:
            item = json.loads(line)
            problem_to_domain[item['problem']] = item['domain']

    for item in target_entry:
        domain = problem_to_domain.get(item['problem'])
        if domain is not None:
            item['domain'] = domain
    return target_entry
|
||
def parse_report(report):
    """Break a '## '-sectioned markdown report into a {title: content} dict.

    The 'Justification' section keeps its full multi-line body; every other
    section is reduced to its first content line ('' when there is none).
    """
    sections = {}
    for chunk in report.split("## ")[1:]:
        body = chunk.strip().split("\n")
        heading = body[0].strip()
        rest = body[1:]
        if heading == "Justification":
            sections[heading] = "\n".join(rest).strip()
        else:
            sections[heading] = rest[0].strip() if rest else ''
    return sections
|
||
def main(args):
    """Print a per-subdomain accuracy table for a GPT-eval jsonl file.

    Each input line is a JSON object with:
      * 'original_json': JSON string of the benchmark entry
        (source, domain, difficulty, problem, answer, model_generation)
      * 'gen': the GPT evaluation report in '## Section' markdown form

    Args:
        args: namespace with an ``input_file`` attribute (path to the jsonl).
    """
    data = []
    with open(args.input_file) as f:
        for line in f:
            data.append(json.loads(line))

    # Flatten each raw line into a record with a boolean 'correctness'.
    target_entry = []
    for entry in data:
        original_json = json.loads(entry['original_json'])
        info = parse_report(entry['gen'])
        if not info:
            continue
        try:
            verdict = info['Equivalence Judgement']
            record = {key: original_json[key]
                      for key in ('source', 'domain', 'difficulty',
                                  'problem', 'answer', 'model_generation')}
        except KeyError:
            # Malformed report or incomplete entry: skip it (the original
            # silently skipped these with a bare except).
            continue
        record['correctness'] = (verdict == 'TRUE')
        target_entry.append(record)

    target_entry = update_domain(target_entry)

    # (correct, total) tallies keyed by the full 'A -> B -> C' chain string.
    domain_acc_dict = {}
    for rec in target_entry:
        for domain_chain in rec['domain']:
            correct, total = domain_acc_dict.get(domain_chain, (0, 0))
            if rec['correctness']:
                correct += 1
            domain_acc_dict[domain_chain] = (correct, total + 1)

    # Fold the chain tallies into a nested tree keyed by chain segment.
    # Each node stores cumulative 'correct'/'total' counters alongside its
    # child segments (so children are siblings of those two keys).
    domain_summary = {}
    for key, (correct, total) in domain_acc_dict.items():
        current = domain_summary
        for part in key.split(" -> "):
            node = current.setdefault(part, {'correct': 0, 'total': 0})
            node['correct'] += correct
            node['total'] += total
            current = node

    table = PrettyTable()
    table.field_names = ["Domain", "Accuracy (Positive / Total)"]

    math_summary = domain_summary['Mathematics']
    # Top-level accuracy — computed but, as in the original, never displayed.
    main_total = math_summary['total']
    main_accuracy = (math_summary['correct'] / main_total) * 100 if main_total > 0 else 0

    # One row per immediate sub-domain of Mathematics.
    for sub_domain, values in math_summary.items():
        if sub_domain in ('correct', 'total'):
            continue  # skip the aggregate counters stored on this node
        sub_correct = values['correct']
        sub_total = values['total']
        sub_accuracy = (sub_correct / sub_total) * 100 if sub_total > 0 else 0
        table.add_row([sub_domain, f"{sub_accuracy:.2f}% ({sub_correct}/{sub_total})"])

    print(table)
|
||
def parse_args():
    """Return the parsed CLI namespace; --input_file is the gpt-eval jsonl path."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--input_file", type=str)  # path to the input jsonl
    return arg_parser.parse_args()
|
||
if __name__ == "__main__":
    # Script entry point: parse CLI options, then run the evaluation.
    main(parse_args())
100 changes: 100 additions & 0 deletions
100
GPT_eval/examples/internlm2-math-plus-mixtral8x22b_gpteval.jsonl
Large diffs are not rendered by default.
Oops, something went wrong.
100 changes: 100 additions & 0 deletions
100
GPT_eval/examples/meta_llama_3-1_70b_instruct_gpteval.jsonl
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import json | ||
import pdb | ||
import argparse # 只是将jsonl转为了list of dict | ||
from collections import defaultdict | ||
def parse_report(report):
    """Parse a '## '-sectioned markdown evaluation report.

    Returns a dict mapping section titles to their content. 'Justification'
    retains its entire multi-line body; any other section is collapsed to
    just its first content line, or '' when the section has no body.
    """
    result = {}
    segments = report.split("## ")
    for segment in segments[1:]:
        seg_lines = segment.strip().split("\n")
        section_title = seg_lines[0].strip()
        if section_title == "Justification":
            result[section_title] = "\n".join(seg_lines[1:]).strip()
        elif len(seg_lines) > 1:
            result[section_title] = seg_lines[1].strip()
        else:
            result[section_title] = ''
    return result
|
||
def main(args):
    """Print the overall accuracy for a GPT-eval jsonl file.

    Each input line holds 'original_json' (JSON string of the benchmark
    entry) and 'gen' (the '## '-sectioned evaluation report).

    Args:
        args: namespace with an ``in_file`` attribute (path to the jsonl).
    """
    data = []
    with open(args.in_file) as f:
        for line in f:
            data.append(json.loads(line))

    # Flatten each raw line into a record with a boolean 'correctness'.
    # NOTE: this script stores the problem text under 'question', unlike the
    # detailed_evaluation scripts which use 'problem'.
    target_entry = []
    for entry in data:
        original_json = json.loads(entry['original_json'])
        info = parse_report(entry['gen'])
        if not info:
            continue
        try:
            verdict = info['Equivalence Judgement']
            record = {
                'source': original_json['source'],
                'domain': original_json['domain'],
                'difficulty': original_json['difficulty'],
                'question': original_json['problem'],
                'answer': original_json['answer'],
                'model_generation': original_json['model_generation'],
            }
        except KeyError:
            # Malformed report or incomplete entry: skip it (the original
            # silently skipped these with a bare except).
            continue
        record['correctness'] = (verdict == 'TRUE')
        target_entry.append(record)

    # Group correctness flags by difficulty.
    grouped_data = defaultdict(list)
    for rec in target_entry:
        grouped_data[rec['difficulty']].append(rec['correctness'])

    # Per-difficulty accuracy (only kept for groups with more than 10
    # samples, as before) plus global tallies.
    accuracy_by_difficulty = {}
    tot_acc = 0
    tot_len = 0
    for difficulty, flags in grouped_data.items():
        total_questions = len(flags)
        correct_answers = flags.count(True)
        tot_len += total_questions
        tot_acc += correct_answers
        if total_questions > 10:
            accuracy_by_difficulty[difficulty] = correct_answers / total_questions

    # Guard against an empty/invalid input file: the original raised
    # ZeroDivisionError here when no entry parsed successfully.
    print('Total Accuracy:{}'.format(tot_acc / tot_len if tot_len else 0))
|
||
|
||
if __name__ == '__main__':
    # CLI entry point: -i/--in-file is required in practice; -o is unused here.
    arg_parser = argparse.ArgumentParser(description="llm gen")
    arg_parser.add_argument("-i", "--in-file", type=str)
    arg_parser.add_argument("-o", "--out-file", type=str)
    main(arg_parser.parse_args())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
#!/usr/bin/env bash
# Run the overall, domain-specific and difficulty-specific evaluations on
# one GPT-eval output file.
#
# Fixes over the original: abort on the first failing step (set -e) instead
# of silently continuing, quote the path expansion, and drop the stray
# trailing line-continuation backslashes that made each command fragile.
set -euo pipefail

INFILE="/cpfs01/shared/public/gaobofei.gbf/hard_math/Omni-MATH/GPT_eval/examples/meta_llama_3-1_70b_instruct_gpteval.jsonl"

python get_result.py -i "$INFILE"

python3 ./detailed_evaluation/domain_specific_evaluation.py --input_file "$INFILE"

python3 ./detailed_evaluation/difficulty_specific_evaluation.py --input_file "$INFILE"
Oops, something went wrong.