Skip to content

Commit

Permalink
Add Spider 1.0 scenario (#3300)
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanmai authored Feb 4, 2025
1 parent 6fb429e commit 5a50569
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 13 deletions.
22 changes: 21 additions & 1 deletion src/helm/benchmark/run_specs/sql_run_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@run_spec_function("bird_sql")
def get_bird_sql_dev() -> RunSpec:
def get_bird_sql_dev_run_spec() -> RunSpec:
scenario_spec = ScenarioSpec(class_name="helm.benchmark.scenarios.bird_sql_scenario.BIRDSQLScenario")

adapter_spec = get_generation_adapter_spec(
Expand All @@ -22,3 +22,23 @@ def get_bird_sql_dev() -> RunSpec:
metric_specs=get_exact_match_metric_specs(),
groups=["bird_sql"],
)


@run_spec_function("spider")
def get_spider_run_spec() -> RunSpec:
    """Assemble the run spec registered under the name "spider".

    The spec points at `SpiderScenario`, uses a generation adapter configured
    with no input/output nouns, `max_tokens=1024` and an empty stop-sequence
    list, and is scored with the exact-match metric specs.

    Returns:
        A `RunSpec` named "spider" that belongs to the "spider" group.
    """
    generation_adapter = get_generation_adapter_spec(
        input_noun=None,
        output_noun=None,
        max_tokens=1024,
        stop_sequences=[],
    )
    spider_scenario = ScenarioSpec(class_name="helm.benchmark.scenarios.spider_scenario.SpiderScenario")

    return RunSpec(
        name="spider",
        scenario_spec=spider_scenario,
        adapter_spec=generation_adapter,
        metric_specs=get_exact_match_metric_specs(),
        groups=["spider"],
    )
81 changes: 81 additions & 0 deletions src/helm/benchmark/scenarios/spider_scenario.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import json
import os
from typing import Dict, List

from filelock import FileLock

from helm.common.general import ensure_directory_exists, ensure_file_downloaded, shell
from helm.common.hierarchical_logger import hlog
from helm.benchmark.scenarios.bird_sql_scenario_helper import ( # type: ignore
cot_wizard,
generate_comment_prompt,
generate_schema_prompt,
)
from helm.benchmark.scenarios.scenario import (
CORRECT_TAG,
Scenario,
Instance,
Reference,
VALID_SPLIT,
Input,
Output,
)


def _ensure_file_unzipped(source_path: str, target_path: str):
    """Unzip `source_path` into `target_path` unless `target_path` already exists.

    The check and the extraction both happen while holding a file lock at
    `<target_path>.lock`. The archive is first extracted into
    `<target_path>.tmp` and that directory is then moved to `target_path`.

    Args:
        source_path: Path of the zip archive to extract.
        target_path: Path where the extracted contents should end up.
    """
    unzip_lock = FileLock(f"{target_path}.lock")
    with unzip_lock:
        if not os.path.exists(target_path):
            # Extract into a staging directory and move it into place afterwards,
            # so `target_path` only appears once unzip has finished.
            staging_path = f"{target_path}.tmp"
            ensure_directory_exists(staging_path)
            shell(["unzip", source_path, "-d", staging_path])
            shell(["mv", staging_path, target_path])
        else:
            hlog(f"Not decompressing {source_path} because {target_path} already exists")


class SpiderScenario(Scenario):
    """Spider 1.0 text-to-SQL scenario.

    Each instance pairs a prompt (database schema + question) with the gold SQL
    query from the `test.json` file of the downloaded `spider_data` archive.
    All instances are emitted in the validation split.
    """

    name = "spider"
    description = "spider"
    tags = ["sql"]

    def get_instances(self, output_path: str) -> List[Instance]:
        """Download the Spider data and build one instance per question.

        Args:
            output_path: Directory under which the archive is downloaded and
                unpacked (into `<output_path>/data`).

        Returns:
            One `Instance` per row of `spider_data/test.json`. The input text is
            the schema prompt of the row's database followed by the question
            prompt, and the single reference is the row's gold SQL query tagged
            as correct.

        Raises:
            KeyError: If a row refers to a `db_id` for which no
                `<db_id>/<db_id>.sqlite` file was found under `test_database`,
                or if a row lacks the `db_id`, `question` or `query` key.
        """
        data_parent_path = os.path.join(output_path, "data")
        ensure_file_downloaded(
            "https://drive.google.com/uc?id=1403EGqzIDoHMdQF4c9Bkyl7dZLZ5Wt6J&export=download&confirm=t",
            data_parent_path,
            unpack=True,
            unpack_type="unzip",
        )
        data_root_path = os.path.join(data_parent_path, "spider_data")
        databases_root_path = os.path.join(data_root_path, "test_database")

        # Build each database's schema prompt once; many questions share a database.
        database_schema_prompts: Dict[str, str] = {}
        for database_name in os.listdir(databases_root_path):
            database_path = os.path.join(databases_root_path, database_name, f"{database_name}.sqlite")
            if not os.path.exists(database_path):
                # Ignore stray ".DS_Store" directory
                continue

            database_schema_prompt = generate_schema_prompt(database_path, num_rows=None)
            database_schema_prompts[database_name] = database_schema_prompt

        instances: List[Instance] = []
        dataset_path = os.path.join(data_root_path, "test.json")
        # Use a context manager so the file handle is closed deterministically,
        # and read as UTF-8 explicitly instead of the platform default encoding.
        with open(dataset_path, "r", encoding="utf-8") as dataset_file:
            dataset = json.load(dataset_file)
        for row in dataset:
            database_id: str = row["db_id"]
            question: str = row["question"]
            gold_sql: str = row["query"]

            schema_prompt = database_schema_prompts[database_id]
            comment_prompt = generate_comment_prompt(question, None)
            # The prompt ends with "SELECT " — presumably so the model continues
            # the query from there; confirm against the metric's expectations.
            combined_prompt = schema_prompt + "\n\n" + comment_prompt + cot_wizard() + "\nSELECT "
            instance = Instance(
                input=Input(text=combined_prompt),
                references=[Reference(output=Output(text=gold_sql), tags=[CORRECT_TAG])],
                split=VALID_SPLIT,
            )
            instances.append(instance)
        return instances
25 changes: 13 additions & 12 deletions src/helm/benchmark/static/schema_sql.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -123,28 +123,29 @@ metric_groups:

############################################################
run_groups:
- name: financial_scenarios
display_name: Financial Scenarios
description: Scenarios for the financial domain
- name: text_to_sql_scenarios
display_name: Text-to-SQL Scenarios
description: Text-to-SQL Scenarios
category: All scenarios
subgroups:
- czech_bank_qa
- spider
- bird_sql

- name: czech_bank_qa
display_name: CzechBankQA
description: The CzechBankQA
- name: spider
display_name: Spider 1.0 (Test)
description: Spider 1.0 (Test)
metric_groups:
- accuracy
- efficiency
- general_information
environment:
main_name: error_rate
main_split: test
main_name: quasi_exact_match
main_split: valid
taxonomy:
task: text-to-SQL
what: queries from financial experts
who: financial experts
when: "1999"
what: databases from various domains
who: expert data scientists
when: "?"
language: English

- name: bird_sql
Expand Down

0 comments on commit 5a50569

Please sign in to comment.