
Commit

mask prompt in loss
tloen committed Mar 19, 2023
1 parent d66908c commit cfad895
Showing 3 changed files with 65 additions and 6 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -4,4 +4,7 @@ out/
 __pycache__/
 checkpoint**
 minimal-llama**
-upload.py
+upload.py
+lora-**
+*ckpt
+wandb
2 changes: 2 additions & 0 deletions README.md
@@ -2,6 +2,8 @@
 
 **Try the pretrained model out on Colab [here](https://colab.research.google.com/drive/1eWAmesrW99p7e1nah5bipn0zikMb8XYC)!**
 
+_**Update 2023-03-19:** weights have been updated with cleaned data and prompts masked out in the loss. This should reduce the number of template artifacts in outputs._
+
 This repository contains code for reproducing the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) results using [low-rank adaptation (LoRA)](https://arxiv.org/pdf/2106.09685.pdf).
 We provide an Instruct model of similar quality to `text-davinci-003` that can run [on a Raspberry Pi](https://twitter.com/miolini/status/1634982361757790209) (for research),
 and the code can be easily extended to the `13b`, `30b`, and `65b` models.
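The update note above relies on a standard PyTorch/Hugging Face convention: label positions set to -100 are excluded from the cross-entropy loss, so masking the prompt tokens means only the response is supervised. A minimal sketch of that mechanism, using plain PyTorch only (illustrative, not code from this repository):

import torch
import torch.nn.functional as F

vocab_size = 8
logits = torch.randn(6, vocab_size)  # toy logits for 6 token positions
# The first three positions stand in for the prompt and are labeled -100,
# so only the last three ("response") positions contribute to the loss.
labels = torch.tensor([-100, -100, -100, 2, 5, 1])

loss = F.cross_entropy(logits, labels)  # ignore_index defaults to -100
print(loss)

Because the prompt tokens no longer contribute gradient signal, the model is trained only to produce the response, which is why template artifacts should show up less often in generations.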
64 changes: 59 additions & 5 deletions finetune.py
@@ -1,6 +1,5 @@
 import os
 
-# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 import torch
 import torch.nn as nn
 import bitsandbytes as bnb
@@ -37,10 +36,10 @@
 DATA_PATH = "alpaca_data_cleaned.json"
 
 device_map = "auto"
-world_size = int(os.environ.get('WORLD_SIZE', 1))
+world_size = int(os.environ.get("WORLD_SIZE", 1))
 ddp = world_size != 1
 if ddp:
-    device_map = {'':int(os.environ.get('LOCAL_RANK') or 0)}
+    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
     GRADIENT_ACCUMULATION_STEPS = GRADIENT_ACCUMULATION_STEPS // world_size
 
 model = LlamaForCausalLM.from_pretrained(
@@ -111,8 +110,60 @@ def tokenize(prompt):
     }
 
 
-train_data = train_data.shuffle().map(lambda x: tokenize(generate_prompt(x)))
-val_data = val_data.shuffle().map(lambda x: tokenize(generate_prompt(x)))
+def generate_and_tokenize_prompt(data_point):
+    # This function masks out the labels for the input,
+    # so that our loss is computed only on the response.
+    user_prompt = (
+        (
+            f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+### Instruction:
+{data_point["instruction"]}
+
+### Input:
+{data_point["input"]}
+
+### Response:
+"""
+        )
+        if data_point["input"]
+        else (
+            f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+### Instruction:
+{data_point["instruction"]}
+
+### Response:
+"""
+        )
+    )
+    len_user_prompt_tokens = (
+        len(
+            tokenizer(
+                user_prompt,
+                truncation=True,
+                max_length=CUTOFF_LEN + 1,
+                padding="max_length",
+            )["input_ids"]
+        )
+        - 1
+    )  # no eos token
+    full_tokens = tokenizer(
+        user_prompt + data_point["output"],
+        truncation=True,
+        max_length=CUTOFF_LEN + 1,
+        padding="max_length",
+    )["input_ids"][:-1]
+    return {
+        "input_ids": full_tokens,
+        "labels": [-100] * len_user_prompt_tokens
+        + full_tokens[len_user_prompt_tokens:],
+        "attention_mask": [1] * (len(full_tokens)),
+    }
+
+
+train_data = train_data.shuffle().map(generate_and_tokenize_prompt)
+val_data = val_data.shuffle().map(generate_and_tokenize_prompt)
 
 trainer = transformers.Trainer(
     model=model,
@@ -144,6 +195,9 @@ def tokenize(prompt):
     lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
 ).__get__(model, type(model))
 
+if torch.__version__ >= "2":
+    model = torch.compile(model)
+
 trainer.train()
 
 model.save_pretrained("lora-alpaca")
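The data layout that `generate_and_tokenize_prompt` produces boils down to the following toy, tokenizer-free sketch (hypothetical helper name and made-up token ids, not part of this commit):

def mask_prompt_labels(prompt_ids, response_ids):
    """Supervise only the response: prompt positions get the ignored label -100."""
    input_ids = prompt_ids + response_ids
    labels = [-100] * len(prompt_ids) + response_ids
    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": [1] * len(input_ids),
    }

print(mask_prompt_labels([5, 6, 7], [8, 9]))
# {'input_ids': [5, 6, 7, 8, 9], 'labels': [-100, -100, -100, 8, 9], 'attention_mask': [1, 1, 1, 1, 1]}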
