-
Notifications
You must be signed in to change notification settings - Fork 178
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add mistral fine-tuning and examples #395
Changes from 9 commits
886cfcc
982f62e
b17add0
a12dd7b
22d7f74
478433f
60967a6
da2f99d
3d4ecf6
8b2aa26
f56acd9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
[flake8] | ||
max-line-length = 88 | ||
extend-ignore = E203,FI10,FI11,FI12,FI13,FI14,FI15,FI16,FI17,FI18 | ||
per-file-ignores = prompt2model/dataset_transformer/prompt_template.py:E501 | ||
ignore = BLK100, W503 | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
"""Example to demonstrate how to create synthetic data based on prompt.""" | ||
|
||
import prompt2model.utils.api_tools as api_tools | ||
from prompt2model.dataset_generator.base import DatasetSplit | ||
from prompt2model.dataset_generator.prompt_based import PromptBasedDatasetGenerator | ||
from prompt2model.prompt_parser import PromptBasedInstructionParser, TaskType | ||
from prompt2model.utils.api_tools import APIAgent | ||
|
||
if __name__ == "__main__": | ||
# set API keys and create default API agent. | ||
api_tools.default_api_agent = APIAgent( | ||
model_name="gpt-3.5-turbo-16k", max_tokens=8000 | ||
) | ||
|
||
# create prompt based on which transform data will be created | ||
prompt = """ | ||
Your task is to generate an answer to a natural question. In this task, the input is a string that consists of both a question and a context passage. The context is a descriptive passage related to the question and contains the answer. And the question can range from Math, Cultural, Social, Geometry, Biology, History, Sports, Technology, Science, and so on. | ||
|
||
Here are examples with input questions and context passages, along with their expected outputs: | ||
|
||
input="Question: What city did Super Bowl 50 take place in? Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50." | ||
output="Santa Clara" | ||
|
||
input="Question: What river runs through Warsaw? Context: Warsaw (Polish: Warszawa [varˈʂava] ( listen); see also other names) is the capital and largest city of Poland. It stands on the Vistula River in east-central Poland, roughly 260 kilometres (160 mi) from the Baltic Sea and 300 kilometres (190 mi) from the Carpathian Mountains. Its population is estimated at 1.740 million residents within a greater metropolitan area of 2.666 million residents, which makes Warsaw the 9th most-populous capital city in the European Union. The city limits cover 516.9 square kilometres (199.6 sq mi), while the metropolitan area covers 6,100.43 square kilometres (2,355.39 sq mi)." | ||
output="Vistula River" | ||
|
||
input="Question: The Ottoman empire controlled territory on three continents, Africa, Asia and which other? Context: The Ottoman Empire was an imperial state that lasted from 1299 to 1923. During the 16th and 17th centuries, in particular at the height of its power under the reign of Suleiman the Magnificent, the Ottoman Empire was a powerful multinational, multilingual empire controlling much of Southeast Europe, Western Asia, the Caucasus, North Africa, and the Horn of Africa. At the beginning of the 17th century the empire contained 32 provinces and numerous vassal states. Some of these were later absorbed into the empire, while others were granted various types of autonomy during the course of centuries." | ||
output="Europe" | ||
""" # noqa: E501 | ||
# parse the prompt to get the instruction and examples | ||
prompt_spec = PromptBasedInstructionParser(task_type=TaskType.TEXT_GENERATION) | ||
prompt_spec.parse_from_prompt(prompt) | ||
print(f"Instruction: {prompt_spec.instruction}\nExamples: {prompt_spec.examples}") | ||
|
||
# set hyperparams | ||
initial_temperature = 0.4 | ||
max_temperature = 1.4 | ||
num_samples_total = 20 | ||
|
||
# run this pipeline to generate data synthetically based on prompt | ||
unlimited_dataset_generator = PromptBasedDatasetGenerator( | ||
initial_temperature=initial_temperature, | ||
max_temperature=max_temperature, | ||
responses_per_request=3, | ||
) | ||
generated_dataset = unlimited_dataset_generator.generate_dataset_split( | ||
prompt_spec, num_samples_total, split=DatasetSplit.TRAIN | ||
) | ||
|
||
# save the final generated dataset to disk | ||
generated_dataset.save_to_disk("demo_generated_dataset") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
"""Example of how to create transform data based on a prompt.""" | ||
|
||
import prompt2model.utils.api_tools as api_tools | ||
from prompt2model.dataset_retriever import DescriptionDatasetRetriever | ||
from prompt2model.prompt_parser import PromptBasedInstructionParser, TaskType | ||
from prompt2model.utils.api_tools import APIAgent | ||
|
||
if __name__ == "__main__": | ||
# set API keys and create default API agent. | ||
api_tools.default_api_agent = APIAgent( | ||
model_name="gpt-3.5-turbo-16k", max_tokens=8000 | ||
) | ||
|
||
# create prompt based on which transform data will be created | ||
prompt = """ | ||
Your task is to generate an answer to a natural question. In this task, the input is a string that consists of both a question and a context passage. The context is a descriptive passage related to the question and contains the answer. And the question can range from Math, Cultural, Social, Geometry, Biology, History, Sports, Technology, Science, and so on. | ||
|
||
Here are examples with input questions and context passages, along with their expected outputs: | ||
|
||
input="Question: What city did Super Bowl 50 take place in? Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50." | ||
output="Santa Clara" | ||
|
||
input="Question: What river runs through Warsaw? Context: Warsaw (Polish: Warszawa [varˈʂava] ( listen); see also other names) is the capital and largest city of Poland. It stands on the Vistula River in east-central Poland, roughly 260 kilometres (160 mi) from the Baltic Sea and 300 kilometres (190 mi) from the Carpathian Mountains. Its population is estimated at 1.740 million residents within a greater metropolitan area of 2.666 million residents, which makes Warsaw the 9th most-populous capital city in the European Union. The city limits cover 516.9 square kilometres (199.6 sq mi), while the metropolitan area covers 6,100.43 square kilometres (2,355.39 sq mi)." | ||
output="Vistula River" | ||
|
||
input="Question: The Ottoman empire controlled territory on three continents, Africa, Asia and which other? Context: The Ottoman Empire was an imperial state that lasted from 1299 to 1923. During the 16th and 17th centuries, in particular at the height of its power under the reign of Suleiman the Magnificent, the Ottoman Empire was a powerful multinational, multilingual empire controlling much of Southeast Europe, Western Asia, the Caucasus, North Africa, and the Horn of Africa. At the beginning of the 17th century the empire contained 32 provinces and numerous vassal states. Some of these were later absorbed into the empire, while others were granted various types of autonomy during the course of centuries." | ||
output="Europe" | ||
""" # noqa: E501 | ||
# parse the prompt to get the instruction and examples | ||
prompt_spec = PromptBasedInstructionParser(task_type=TaskType.TEXT_GENERATION) | ||
prompt_spec.parse_from_prompt(prompt) | ||
print(f"Instruction: {prompt_spec.instruction}\nExamples: {prompt_spec.examples}") | ||
|
||
# run this pipeline to retrieve relevant datasets, rerank them, | ||
# and transform them based on the prompt | ||
retriever = DescriptionDatasetRetriever() | ||
num_points_to_transform = 20 | ||
retrieved_dataset_dict = retriever.retrieve_dataset_dict( | ||
prompt_spec, | ||
auto_transform_data=True, | ||
num_points_to_transform=num_points_to_transform, | ||
) | ||
|
||
# save the final dataset to disk | ||
if retrieved_dataset_dict is not None: | ||
retrieved_dataset_dict.save_to_disk("demo_retrieved_dataset_dict") |
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
"""Example of how to fine-tune a model using the QLoraTrainer class.""" | ||
|
||
import os | ||
|
||
from datasets import load_from_disk | ||
|
||
from prompt2model.model_trainer.qlora_trainer import QLoraTrainer | ||
from prompt2model.utils.dataset_utils import format_train_data, make_combined_datasets | ||
|
||
if __name__ == "__main__": | ||
# load in datasets | ||
retrieved_dataset_dict = load_from_disk("demo_retrieved_dataset_dict") | ||
retrieved_dataset = retrieved_dataset_dict["train"] | ||
generated_dataset = load_from_disk("demo_generated_dataset") | ||
dataset_list = [retrieved_dataset, generated_dataset] | ||
|
||
# combine datasets and create train and eval splits | ||
train_dataset = make_combined_datasets(dataset_list) | ||
splits = train_dataset.train_test_split(test_size=0.1) | ||
train_dataset = splits["train"] # has 2 cols: "input_col" and "output_col" | ||
eval_dataset = splits["test"] # has 2 cols: "input_col" and "output_col" | ||
formatted_train_dataset = format_train_data( | ||
train_dataset | ||
) # combined into one col: "text" | ||
formatted_eval_dataset = format_train_data( | ||
eval_dataset | ||
) # combined into one col: "text" | ||
saum7800 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# set hyperparams | ||
num_epochs = 1 | ||
qlora_alpha = 8 | ||
qlora_r = 16 | ||
qlora_lr = 1e-5 | ||
save_folder_path = "qlora_finetuned_model" | ||
load_best_model_at_end = False | ||
|
||
saum7800 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
trainer = QLoraTrainer(model_name="mistralai/Mistral-7B-v0.1", model_max_length=512) | ||
|
||
trained_model, trained_tokenizer = trainer.train_model( | ||
formatted_train_dataset, # passed for fine-tuning the model | ||
formatted_eval_dataset, # passed for calculating eval "loss" over time | ||
eval_dataset, # passed for calculating eval "accuracy" over time | ||
# (whether generated output matches expected output) | ||
saum7800 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
num_epochs=1, | ||
alpha=qlora_alpha, | ||
r=qlora_r, | ||
lr=qlora_lr, | ||
save_folder_path=save_folder_path, | ||
load_best_model_at_end=load_best_model_at_end, | ||
) | ||
trained_model.save_pretrained(os.path.join(save_folder_path, "demo_final_model")) | ||
trained_tokenizer.save_pretrained( | ||
os.path.join(save_folder_path, "demo_final_tokenizer") | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,10 @@ | |
get_formatted_logger, | ||
handle_api_error, | ||
) | ||
from prompt2model.utils.parse_responses import make_single_api_request, parse_json | ||
from prompt2model.utils.parse_responses import ( | ||
find_and_parse_json, | ||
make_single_api_request, | ||
) | ||
|
||
logger = get_formatted_logger("DatasetTransformer") | ||
|
||
|
@@ -29,7 +32,7 @@ class PromptBasedDatasetTransformer(DatasetTransformer): | |
def __init__( | ||
self, | ||
plan_prompt_fn: Callable[ | ||
[str, list[dict], str], str | ||
[str, list[dict], str, int], str | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With a complicated function interface like this, it would be good to explain what this function is (and why this particular interface is required) in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure, did this. also rearranged variables to make more logical sense |
||
] = construct_prompt_for_plan, | ||
transform_prompt_fn: Callable[ | ||
[str, dict, str, str], str | ||
|
@@ -80,6 +83,7 @@ def transform_data( | |
prompt_spec.instruction, | ||
dataset, | ||
prompt_spec.examples, | ||
min(5, len(dataset)), | ||
) | ||
self.plan = make_single_api_request(plan_prompt) | ||
|
||
|
@@ -121,7 +125,7 @@ async def generate_responses(transform_prompts): | |
|
||
for response in responses: | ||
try: | ||
extraction = parse_json(response, ["input", "output"], []) | ||
extraction = find_and_parse_json(response, ["input", "output"], []) | ||
if extraction is not None: | ||
inputs.append(str(extraction["input"])) | ||
outputs.append(str(extraction["output"])) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think these two can actually be ignored globally, not just for this file.
BLK100 can be ignored because we are already running black in a separate action, and W503 seems to be commonly ignored by developers, since it conflicts with another warning in flake8, W504.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, makes sense. "ignore" would have ignored globally, but I could just add it to extend-ignore directly instead of creating a new line for it.