
Add Mistral fine-tuning and examples #395

Merged (11 commits, Mar 11, 2024)
Changes from 9 commits
2 changes: 2 additions & 0 deletions .flake8
@@ -1,3 +1,5 @@
[flake8]
max-line-length = 88
extend-ignore = E203,FI10,FI11,FI12,FI13,FI14,FI15,FI16,FI17,FI18
per-file-ignores = prompt2model/dataset_transformer/prompt_template.py:E501
ignore = BLK100, W503
Collaborator:

I think these two can actually be ignored globally, not just for this file.

BLK100 can be ignored because we are already running black in a separate action, and W503 seems to be commonly ignored by developers, since it conflicts with another warning in flake8, W504.

Collaborator (author):

Yeah, makes sense. `ignore` would have applied globally, but it also replaces flake8's default ignore list, so I'll just add these codes to `extend-ignore` directly instead of creating a new line for them.
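For reference, the consolidated `.flake8` discussed above might look like this (a sketch of the suggested change, not the committed version):

```ini
[flake8]
max-line-length = 88
# BLK100: black already runs in a separate CI action; W503 conflicts with W504.
extend-ignore = E203,BLK100,W503,FI10,FI11,FI12,FI13,FI14,FI15,FI16,FI17,FI18
per-file-ignores = prompt2model/dataset_transformer/prompt_template.py:E501
```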

51 changes: 51 additions & 0 deletions examples/create_synthetic_data_example.py
@@ -0,0 +1,51 @@
"""Example demonstrating how to create synthetic data based on a prompt."""

import prompt2model.utils.api_tools as api_tools
from prompt2model.dataset_generator.base import DatasetSplit
from prompt2model.dataset_generator.prompt_based import PromptBasedDatasetGenerator
from prompt2model.prompt_parser import PromptBasedInstructionParser, TaskType
from prompt2model.utils.api_tools import APIAgent

if __name__ == "__main__":
# set API keys and create default API agent.
api_tools.default_api_agent = APIAgent(
model_name="gpt-3.5-turbo-16k", max_tokens=8000
)

# create the prompt from which synthetic data will be generated
prompt = """
Your task is to generate an answer to a natural question. In this task, the input is a string that consists of both a question and a context passage. The context is a descriptive passage related to the question and contains the answer. The questions can cover topics such as math, culture, society, geometry, biology, history, sports, technology, and science.

Here are examples with input questions and context passages, along with their expected outputs:

input="Question: What city did Super Bowl 50 take place in? Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50."
output="Santa Clara"

input="Question: What river runs through Warsaw? Context: Warsaw (Polish: Warszawa [varˈʂava] ( listen); see also other names) is the capital and largest city of Poland. It stands on the Vistula River in east-central Poland, roughly 260 kilometres (160 mi) from the Baltic Sea and 300 kilometres (190 mi) from the Carpathian Mountains. Its population is estimated at 1.740 million residents within a greater metropolitan area of 2.666 million residents, which makes Warsaw the 9th most-populous capital city in the European Union. The city limits cover 516.9 square kilometres (199.6 sq mi), while the metropolitan area covers 6,100.43 square kilometres (2,355.39 sq mi)."
output="Vistula River"

input="Question: The Ottoman empire controlled territory on three continents, Africa, Asia and which other? Context: The Ottoman Empire was an imperial state that lasted from 1299 to 1923. During the 16th and 17th centuries, in particular at the height of its power under the reign of Suleiman the Magnificent, the Ottoman Empire was a powerful multinational, multilingual empire controlling much of Southeast Europe, Western Asia, the Caucasus, North Africa, and the Horn of Africa. At the beginning of the 17th century the empire contained 32 provinces and numerous vassal states. Some of these were later absorbed into the empire, while others were granted various types of autonomy during the course of centuries."
output="Europe"
""" # noqa: E501
# parse the prompt to get the instruction and examples
prompt_spec = PromptBasedInstructionParser(task_type=TaskType.TEXT_GENERATION)
prompt_spec.parse_from_prompt(prompt)
print(f"Instruction: {prompt_spec.instruction}\nExamples: {prompt_spec.examples}")

# set hyperparams
initial_temperature = 0.4
max_temperature = 1.4
num_samples_total = 20

# run this pipeline to generate data synthetically based on prompt
unlimited_dataset_generator = PromptBasedDatasetGenerator(
initial_temperature=initial_temperature,
max_temperature=max_temperature,
responses_per_request=3,
)
generated_dataset = unlimited_dataset_generator.generate_dataset_split(
prompt_spec, num_samples_total, split=DatasetSplit.TRAIN
)

# save the final generated dataset to disk
generated_dataset.save_to_disk("demo_generated_dataset")
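The temperature schedule between `initial_temperature` and `max_temperature` is internal to `PromptBasedDatasetGenerator` and not shown in this diff; one plausible reading of the two hyperparameters is a linear ramp across API requests, sketched below with a hypothetical `ramp_temperature` helper (illustrative only):

```python
# Hypothetical helper illustrating one possible temperature schedule;
# the actual schedule inside PromptBasedDatasetGenerator may differ.
def ramp_temperature(
    request_idx: int,
    total_requests: int,
    initial: float = 0.4,
    maximum: float = 1.4,
) -> float:
    """Linearly interpolate the sampling temperature across requests."""
    if total_requests <= 1:
        return initial
    fraction = request_idx / (total_requests - 1)
    return initial + fraction * (maximum - initial)
```

Raising the temperature over successive requests trades fidelity on early samples for diversity on later ones, which is why a range rather than a single temperature is exposed.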
46 changes: 46 additions & 0 deletions examples/create_transform_data_example.py
@@ -0,0 +1,46 @@
"""Example of how to retrieve and transform data based on a prompt."""

import prompt2model.utils.api_tools as api_tools
from prompt2model.dataset_retriever import DescriptionDatasetRetriever
from prompt2model.prompt_parser import PromptBasedInstructionParser, TaskType
from prompt2model.utils.api_tools import APIAgent

if __name__ == "__main__":
# set API keys and create default API agent.
api_tools.default_api_agent = APIAgent(
model_name="gpt-3.5-turbo-16k", max_tokens=8000
)

# create the prompt based on which data will be retrieved and transformed
prompt = """
Your task is to generate an answer to a natural question. In this task, the input is a string that consists of both a question and a context passage. The context is a descriptive passage related to the question and contains the answer. The questions can cover topics such as math, culture, society, geometry, biology, history, sports, technology, and science.

Here are examples with input questions and context passages, along with their expected outputs:

input="Question: What city did Super Bowl 50 take place in? Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50."
output="Santa Clara"

input="Question: What river runs through Warsaw? Context: Warsaw (Polish: Warszawa [varˈʂava] ( listen); see also other names) is the capital and largest city of Poland. It stands on the Vistula River in east-central Poland, roughly 260 kilometres (160 mi) from the Baltic Sea and 300 kilometres (190 mi) from the Carpathian Mountains. Its population is estimated at 1.740 million residents within a greater metropolitan area of 2.666 million residents, which makes Warsaw the 9th most-populous capital city in the European Union. The city limits cover 516.9 square kilometres (199.6 sq mi), while the metropolitan area covers 6,100.43 square kilometres (2,355.39 sq mi)."
output="Vistula River"

input="Question: The Ottoman empire controlled territory on three continents, Africa, Asia and which other? Context: The Ottoman Empire was an imperial state that lasted from 1299 to 1923. During the 16th and 17th centuries, in particular at the height of its power under the reign of Suleiman the Magnificent, the Ottoman Empire was a powerful multinational, multilingual empire controlling much of Southeast Europe, Western Asia, the Caucasus, North Africa, and the Horn of Africa. At the beginning of the 17th century the empire contained 32 provinces and numerous vassal states. Some of these were later absorbed into the empire, while others were granted various types of autonomy during the course of centuries."
output="Europe"
""" # noqa: E501
# parse the prompt to get the instruction and examples
prompt_spec = PromptBasedInstructionParser(task_type=TaskType.TEXT_GENERATION)
prompt_spec.parse_from_prompt(prompt)
print(f"Instruction: {prompt_spec.instruction}\nExamples: {prompt_spec.examples}")

# run this pipeline to retrieve relevant datasets, rerank them,
# and transform them based on the prompt
retriever = DescriptionDatasetRetriever()
num_points_to_transform = 20
retrieved_dataset_dict = retriever.retrieve_dataset_dict(
prompt_spec,
auto_transform_data=True,
num_points_to_transform=num_points_to_transform,
)

# save the final dataset to disk
if retrieved_dataset_dict is not None:
retrieved_dataset_dict.save_to_disk("demo_retrieved_dataset_dict")
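`DescriptionDatasetRetriever`'s retrieval logic is internal to the library (and presumably embedding-based); the naive sketch below, using a hypothetical `rank_datasets_by_description` helper, only illustrates the underlying idea of ranking candidate datasets by how well their descriptions match the prompt:

```python
from difflib import SequenceMatcher


# Hypothetical helper; DescriptionDatasetRetriever uses its own retrieval,
# not this naive string-similarity ranking.
def rank_datasets_by_description(
    prompt: str, descriptions: dict[str, str]
) -> list[str]:
    """Rank candidate dataset names by similarity between the task
    prompt and each dataset's description (best match first)."""

    def score(name: str) -> float:
        return SequenceMatcher(
            None, prompt.lower(), descriptions[name].lower()
        ).ratio()

    return sorted(descriptions, key=score, reverse=True)
```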

Two large diffs are not rendered by default.

54 changes: 54 additions & 0 deletions examples/mistral_qlora_finetune_example.py
@@ -0,0 +1,54 @@
"""Example of how to fine-tune a model using the QLoraTrainer class."""

import os

from datasets import load_from_disk

from prompt2model.model_trainer.qlora_trainer import QLoraTrainer
from prompt2model.utils.dataset_utils import format_train_data, make_combined_datasets

if __name__ == "__main__":
# load in datasets
retrieved_dataset_dict = load_from_disk("demo_retrieved_dataset_dict")
retrieved_dataset = retrieved_dataset_dict["train"]
generated_dataset = load_from_disk("demo_generated_dataset")
dataset_list = [retrieved_dataset, generated_dataset]

# combine datasets and create train and eval splits
train_dataset = make_combined_datasets(dataset_list)
splits = train_dataset.train_test_split(test_size=0.1)
train_dataset = splits["train"] # has 2 cols: "input_col" and "output_col"
eval_dataset = splits["test"] # has 2 cols: "input_col" and "output_col"
formatted_train_dataset = format_train_data(
train_dataset
) # combined into one col: "text"
formatted_eval_dataset = format_train_data(
eval_dataset
) # combined into one col: "text"

# set hyperparams
num_epochs = 1
qlora_alpha = 8
qlora_r = 16
qlora_lr = 1e-5
save_folder_path = "qlora_finetuned_model"
load_best_model_at_end = False

trainer = QLoraTrainer(model_name="mistralai/Mistral-7B-v0.1", model_max_length=512)

trained_model, trained_tokenizer = trainer.train_model(
formatted_train_dataset, # passed for fine-tuning the model
formatted_eval_dataset, # passed for calculating eval "loss" over time
eval_dataset, # passed for calculating eval "accuracy" over time
# (whether generated output matches expected output)
num_epochs=num_epochs,
alpha=qlora_alpha,
r=qlora_r,
lr=qlora_lr,
save_folder_path=save_folder_path,
load_best_model_at_end=load_best_model_at_end,
)
trained_model.save_pretrained(os.path.join(save_folder_path, "demo_final_model"))
trained_tokenizer.save_pretrained(
os.path.join(save_folder_path, "demo_final_tokenizer")
)
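The eval "accuracy" mentioned in the comments above (whether the generated output matches the expected output) is computed inside the trainer; a minimal exact-match sketch, assuming simple whitespace and case normalization, could look like the hypothetical helper below (not the trainer's actual metric):

```python
# Hypothetical exact-match metric; the trainer's real eval metric may differ.
def exact_match_accuracy(predictions: list[str], references: list[str]) -> float:
    """Fraction of predictions equal to the expected output after
    whitespace and case normalization."""
    if not predictions:
        return 0.0
    matches = sum(
        pred.strip().lower() == ref.strip().lower()
        for pred, ref in zip(predictions, references)
    )
    return matches / len(predictions)
```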
10 changes: 7 additions & 3 deletions prompt2model/dataset_transformer/prompt_based.py
@@ -18,7 +18,10 @@
get_formatted_logger,
handle_api_error,
)
from prompt2model.utils.parse_responses import make_single_api_request, parse_json
from prompt2model.utils.parse_responses import (
find_and_parse_json,
make_single_api_request,
)

logger = get_formatted_logger("DatasetTransformer")

@@ -29,7 +32,7 @@ class PromptBasedDatasetTransformer(DatasetTransformer):
def __init__(
self,
plan_prompt_fn: Callable[
[str, list[dict], str], str
[str, list[dict], str, int], str
Collaborator:

With a complicated function interface like this, it would be good to explain what this function is (and why this particular interface is required) in the "Args:" section of this __init__ function's docstring.

Collaborator (author):

Sure, did this. Also rearranged the parameters into a more logical order.

] = construct_prompt_for_plan,
transform_prompt_fn: Callable[
[str, dict, str, str], str
@@ -80,6 +83,7 @@ def transform_data(
prompt_spec.instruction,
dataset,
prompt_spec.examples,
min(5, len(dataset)),
)
self.plan = make_single_api_request(plan_prompt)

@@ -121,7 +125,7 @@ async def generate_responses(transform_prompts):

for response in responses:
try:
extraction = parse_json(response, ["input", "output"], [])
extraction = find_and_parse_json(response, ["input", "output"], [])
if extraction is not None:
inputs.append(str(extraction["input"]))
outputs.append(str(extraction["output"]))
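The implementation of `find_and_parse_json` is not part of this diff; a minimal stdlib sketch of what such a helper might do (find the first JSON object in an LLM response and validate required keys) is shown below. The function name, signature, and behavior here are assumptions, not the actual code in `prompt2model.utils.parse_responses`:

```python
import json
import re


# Sketch of a find-and-parse helper; the real find_and_parse_json in
# prompt2model.utils.parse_responses may behave differently.
def find_and_parse_json_sketch(response: str, required_keys: list[str]):
    """Extract the first {...} block from an LLM response, parse it as
    JSON, and return the dict only if every required key is present."""
    match = re.search(r"\{.*\}", response, re.DOTALL)
    if match is None:
        return None
    try:
        parsed = json.loads(match.group(0))
    except json.JSONDecodeError:
        return None
    if not all(key in parsed for key in required_keys):
        return None
    return parsed
```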