feat: add fine-tuning with deployments (#357)
1 parent e8ecae9 · commit 99b77c0
Showing 29 changed files with 3,739 additions and 639 deletions.
@@ -2,7 +2,7 @@
 playground/

 # Downloaded data for examples
-examples/.data
+examples/**/.data/

 # Pickle files
 *.pkl
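The broadened ignore pattern (examples/**/.data/ instead of examples/.data) presumably covers per-example .data directories, such as the one the new fine-tuning example below creates next to its own script.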
@@ -0,0 +1,3 @@
"""
Deployment
"""
@@ -0,0 +1,36 @@
"""
Get all deployed models
"""

from pprint import pprint

from dotenv import load_dotenv

from genai.client import Client
from genai.credentials import Credentials

load_dotenv()


def heading(text: str) -> str:
    """Helper function for centering text."""
    return "\n" + f" {text} ".center(80, "=") + "\n"


# make sure you have a .env file under genai root with
# GENAI_KEY=<your-genai-key>
# GENAI_API=<genai-api-endpoint> (optional) DEFAULT_API = "https://bam-api.res.ibm.com"
client = Client(credentials=Credentials.from_env())

print(heading("Get list of deployed models"))
deployment_list = client.deployment.list()
for deployment in deployment_list.results:
    pprint(deployment.model_dump())

if len(deployment_list.results) < 1:
    print("No deployed models found.")
    exit(1)

print(heading("Retrieve information about first deployment"))
deployment_info = client.deployment.retrieve(id=deployment_list.results[0].id)
pprint(deployment_info.model_dump())
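Both example scripts read their credentials from a .env file in the genai root, as the comments above note. A minimal .env might look like the following (placeholder values; GENAI_API is optional and defaults to the endpoint shown in the comment):

GENAI_KEY=your-genai-key
GENAI_API=https://bam-api.res.ibm.com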
@@ -0,0 +1,152 @@
"""
Fine tune and deploy custom model

Use custom training data to tune a model for text generation.

Note:
    This example has been written to enable an end-user to quickly try fine-tuning. In order to obtain better
    performance, a user would need to experiment with the number of observations and tuning hyperparameters.
"""

import time
from pathlib import Path

from dotenv import load_dotenv

from genai.client import Client
from genai.credentials import Credentials
from genai.schema import (
    DecodingMethod,
    DeploymentStatus,
    FilePurpose,
    TextGenerationParameters,
    TuneParameters,
    TuneStatus,
)

load_dotenv()
num_training_samples = 50
num_validation_samples = 20
data_root = Path(__file__).parent.resolve() / ".data"
training_file = data_root / "fpb_train.jsonl"
validation_file = data_root / "fpb_validation.jsonl"


def heading(text: str) -> str:
    """Helper function for centering text."""
    return "\n" + f" {text} ".center(80, "=") + "\n"


def create_dataset():
    Path(data_root).mkdir(parents=True, exist_ok=True)
    if training_file.exists():
        print("Dataset is already prepared")
        return

    try:
        import pandas as pd
        from datasets import load_dataset
    except ImportError:
        print("Please install datasets and pandas for downloading the dataset.")
        raise

    data = load_dataset("locuslab/TOFU")
    df = pd.DataFrame(data["train"])
    df.rename(columns={"question": "input", "answer": "output"}, inplace=True)
    df["output"] = df["output"].astype(str)
    train_jsonl = df.iloc[:num_training_samples].to_json(orient="records", lines=True, force_ascii=True)
    validation_jsonl = df.iloc[-num_validation_samples:].to_json(orient="records", lines=True, force_ascii=True)
    with open(training_file, "w") as fout:
        fout.write(train_jsonl)
    with open(validation_file, "w") as fout:
        fout.write(validation_jsonl)
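# Each line of the generated fpb_train.jsonl / fpb_validation.jsonl above is a single JSON
# record with "input" and "output" keys; the values are question/answer pairs taken from the
# TOFU dataset. A record has roughly this shape (placeholder text, not real data):
#   {"input": "Some question ...", "output": "Some answer ..."}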


def upload_files(client: Client, update=True):
    files_info = client.file.list(search=training_file.name).results
    files_info += client.file.list(search=validation_file.name).results

    filenames_to_id = {f.file_name: f.id for f in files_info}
    for filepath in [training_file, validation_file]:
        filename = filepath.name
        if filename in filenames_to_id and update:
            print(f"File already present: Overwriting {filename}")
            client.file.delete(filenames_to_id[filename])
            response = client.file.create(file_path=filepath, purpose=FilePurpose.TUNE)
            filenames_to_id[filename] = response.result.id
        if filename not in filenames_to_id:
            print(f"File not present: Uploading {filename}")
            response = client.file.create(file_path=filepath, purpose=FilePurpose.TUNE)
            filenames_to_id[filename] = response.result.id
    return filenames_to_id[training_file.name], filenames_to_id[validation_file.name]


# make sure you have a .env file under genai root with
# GENAI_KEY=<your-genai-key>
# GENAI_API=<genai-api-endpoint> (optional) DEFAULT_API = "https://bam-api.res.ibm.com"
client = Client(credentials=Credentials.from_env())

print(heading("Creating dataset"))
create_dataset()

print(heading("Uploading files"))
training_file_id, validation_file_id = upload_files(client, update=True)

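# The verbalizer below controls how each JSONL record is rendered into a training prompt;
# {{input}} and {{output}} are filled from the "input" / "output" keys written by create_dataset().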
hyperparams = TuneParameters(
    num_epochs=4,
    verbalizer="### Input: {{input}} ### Response: {{output}}",
    batch_size=4,
    learning_rate=0.4,
    # Advanced parameters are not defined in the schema
    # but can be passed to the API
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    per_device_train_batch_size=4,
    num_train_epochs=4,
)
print(heading("Tuning model"))

tune_result = client.tune.create(
    model_id="meta-llama/llama-3-8b-instruct",
    name="generation-fine-tune-example",
    tuning_type="fine_tuning",
    task_id="generation",
    parameters=hyperparams,
    training_file_ids=[training_file_id],
    # validation_file_ids=[validation_file_id],  # TODO: Broken at the moment - this causes tune to fail
).result

while tune_result.status not in [TuneStatus.FAILED, TuneStatus.HALTED, TuneStatus.COMPLETED]:
    new_tune_result = client.tune.retrieve(tune_result.id).result
    print(f"Waiting for tune to finish, current status: {tune_result.status}")
    tune_result = new_tune_result
    time.sleep(10)

if tune_result.status in [TuneStatus.FAILED, TuneStatus.HALTED]:
    print("Model tuning failed or halted")
    exit(1)

print(heading("Deploying fine-tuned model"))

deployment = client.deployment.create(tune_id=tune_result.id).result

while deployment.status not in [DeploymentStatus.READY, DeploymentStatus.FAILED, DeploymentStatus.EXPIRED]:
    deployment = client.deployment.retrieve(id=deployment.id).result
    print(f"Waiting for deployment to finish, current status: {deployment.status}")
    time.sleep(10)

if deployment.status in [DeploymentStatus.FAILED, DeploymentStatus.EXPIRED]:
    print(f"Model deployment failed or expired, status: {deployment.status}")
    exit(1)

print(heading("Generate text with fine-tuned model"))
prompt = "What are some books you would recommend to read?"
print("Prompt: ", prompt)
# Use the sampling parameters defined here for the generation request
gen_params = TextGenerationParameters(decoding_method=DecodingMethod.SAMPLE)
gen_response = next(client.text.generation.create(model_id=tune_result.id, inputs=[prompt], parameters=gen_params))

print("Answer: ", gen_response.results[0].generated_text)

print(heading("Deleting deployment and tuned model"))
client.deployment.delete(id=deployment.id)
client.tune.delete(id=tune_result.id)
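The script above deletes the deployment and the tune, but the JSONL files uploaded by upload_files() are left behind. A minimal cleanup sketch, reusing only the client.file.list / client.file.delete calls already used above (an optional extra, not part of the original example):

# Optional cleanup: remove the uploaded training/validation files as well
for name in [training_file.name, validation_file.name]:
    for file_info in client.file.list(search=name).results:
        print(f"Deleting uploaded file {file_info.file_name} ({file_info.id})")
        client.file.delete(file_info.id)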