Commit 2e9c24c: 23 changed files with 3,114 additions and 0 deletions.

@@ -0,0 +1,48 @@ (.gitignore)

# Ignore Mac system files
.DS_Store

# Ignore node_modules folder
node_modules

__pycache__
pic_draw
ckpt
cache

# Ignore files related to API keys
*.env

# Ignore SASS cache files
*.sass-cache

# Ignore model weights, JSON outputs, and logs
*.pt
*.json
*.log
*.out

# Ignore IDE and Python bytecode files
.idea
*.pyc

# Ignore local copies and figures
data_copy
script_copy
ckpt_copy
fig

# Ignore local plotting and test scripts
*draw_pic.py
*prune_test.py
*visual_test.py

@@ -0,0 +1,86 @@ (README.md)

<!-- markdownlint-disable first-line-h1 -->
<!-- markdownlint-disable html -->

<h1 align="center">Lisa: Lazy Safety Alignment for Large Language Models against Harmful Fine-tuning</h1>

Lisa is a safety alignment method against the threat of harmful fine-tuning. We consider a two-stage fine-tuning scheme: i) an alignment stage, in which we align the model with a human-preference dataset (the alignment dataset), and ii) a fine-tuning stage, in which we fine-tune the model on a user fine-tuning dataset (which is mixed with harmful instances). Lisa is applied in the fine-tuning stage, where a Bi-state optimization with a proximal term is used to mitigate the risk from the mixed-in harmful data.

## Main code logic
We implement a customized trainer on top of the original HuggingFace Trainer. To achieve Bi-state optimization, we append one line of code in the function `training_step()` of `trainer.py`.

```
inputs = self.check_mode(inputs)  # Appended code: switch dataset/model according to the step number
loss = step()                     # Gradient backward pass with the selected data and model
```
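
The switching itself happens inside `check_mode()`. As a rough, self-contained sketch of the idea (hypothetical helper and argument names, not the repository's actual implementation), the state for each optimization step can be derived from the global step counter alone:

```
def select_mode(global_step, align_steps, finetune_steps):
    """Illustrative helper (not the repository's code): Lisa alternates between an
    "alignment" state and a "finetune" state according to the step counter."""
    cycle = align_steps + finetune_steps
    return "alignment" if global_step % cycle < align_steps else "finetune"


# Example: with 100 alignment steps followed by 900 fine-tuning steps per cycle,
# step 42 is still in the alignment state and step 500 is in the fine-tuning state.
assert select_mode(42, 100, 900) == "alignment"
assert select_mode(500, 100, 900) == "finetune"
```
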
To introduce a proximal term towards consensus, we add the following regularization to the loss in the function `step()`.

```
if self.status == "alignment":
    for name, param in model.named_parameters():
        if param.requires_grad and self.args.rho > 0:
            loss += self.args.rho / 2 * torch.norm(param - self.alignment_weights[name]) ** 2
else:
    for name, param in model.named_parameters():
        if param.requires_grad and self.args.rho > 0:
            loss += self.args.rho / 2 * torch.norm(param - self.finetune_weights[name]) ** 2
```
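
The anchors `self.alignment_weights` and `self.finetune_weights` are snapshots of the model taken in the other state. A minimal sketch of how such snapshots could be captured, assuming detached copies of the trainable parameters are enough (the helper name is hypothetical, not the repository's code):

```
import torch
import torch.nn as nn


def snapshot_weights(model):
    """Return a detached copy of every trainable parameter, keyed by name.

    Illustrative helper: the stored tensors can serve as the anchor weights in the
    proximal term rho/2 * ||param - anchor||**2 shown above."""
    return {
        name: param.detach().clone()
        for name, param in model.named_parameters()
        if param.requires_grad
    }


# Tiny demo on a stand-in module; in the trainer these snapshots would play the
# role of self.alignment_weights / self.finetune_weights.
demo = nn.Linear(4, 2)
anchor = snapshot_weights(demo)
prox = sum(0.5 * torch.norm(p - anchor[n]) ** 2 for n, p in demo.named_parameters())
print(prox)  # zero right after the snapshot, grows as the weights drift away
```
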

## Package requirement
The package requirements are listed in `lisa.yml` and `lisa_pip.txt`. Run the following commands to install the packages with Anaconda and pip.
```
conda env create -f lisa.yml
pip install -r lisa_pip.txt
```

## Data preparation
For the fine-tuning tasks, we first need to run the following scripts to prepare the supervised fine-tuning data.
```
cd sst2
python build_dataset.py
cd ../gsm8k
python build_dataset.py
cd ../ag_news
python build_dataset.py
cd ..
```

## Huggingface Llama2 access
Llama2-7B is a gated repository, which requires a formal request to get access to the model. Check out https://huggingface.co/meta-llama/Llama-2-7b-hf.
After Meta grants your request, you should be able to access the model, but you first need to put your access token in the file `huggingface_token.txt`.
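
The scripts read this token from `huggingface_token.txt` and pass it to the `transformers` loading calls; the sketch below mirrors that pattern (run it from a script sub-directory so the relative path resolves; the precision and device-map settings are taken from the evaluation script rather than mandated here):

```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# The first line of huggingface_token.txt holds the access token.
access_token = next(open('../huggingface_token.txt')).strip()

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", token=access_token)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, device_map="auto", token=access_token
)
```
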

## Example commands to run

We prepare scripts for reproducing all the experiments in the paper. We recommend using Slurm to reproduce the results, as the logging files will be automatically organized into the script directory (if you don't use Slurm, just replace `sbatch` with `bash` in our examples).

We first run SFT to produce the aligned model.
```
cd script/alignment
sbatch SFT.sh
```
Then we fine-tune the model using 10% harmful data within a total of 5000 samples from the SST2 dataset.
```
cd ../finetune
sbatch lisa_poison_ratio.sh 0.1
```

For comparison, we fine-tune the model with SFT in the same data setting.

```
sbatch sft_poison_ratio.sh 0.1
cd ../..
```

@@ -0,0 +1,31 @@ (ag_news/build_dataset.py)

import random
import json
import os
import argparse

random.seed(0)

parser = argparse.ArgumentParser()
args = parser.parse_args()


from datasets import load_dataset

# Build an Alpaca-style instruction dataset from the AG News training split.
dataset = load_dataset("ag_news")
output_json = '../data/agnews.json'
output_data_lst = []
for data in dataset["train"]:
    print(data)
    item = {}
    item["instruction"] = "Categorize the news article given in the input into one of the 4 categories:\n\nWorld\nSports\nBusiness\nSci/Tech\n"
    item["input"] = data["text"]
    # Map the integer AG News label to its category name.
    if data["label"] == 0:
        item["output"] = "World"
    elif data["label"] == 1:
        item["output"] = "Sports"
    elif data["label"] == 2:
        item["output"] = "Business"
    else:
        item["output"] = "Sci/Tech"
    output_data_lst += [item]
with open(output_json, 'w', encoding='utf-8') as f:
    json.dump(output_data_lst, f, indent=4)
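
A quick, optional sanity check of the generated file (a usage sketch; the path matches `output_json` above):

```
import json

# Inspect the first record of the generated AG News instruction data.
with open('../data/agnews.json', encoding='utf-8') as f:
    records = json.load(f)

print(len(records))                    # number of training examples
print(records[0]["instruction"][:60])  # the shared categorization instruction
print(records[0]["input"][:60])        # the news article text
print(records[0]["output"])            # one of: World, Sports, Business, Sci/Tech
```
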

@@ -0,0 +1,125 @@ (AG News evaluation script)

import os
import json
import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
from peft import PeftModel

access_token = next(open('../huggingface_token.txt')).strip()

parser = argparse.ArgumentParser()
parser.add_argument("--model_folder", default='wxjiao/alpaca-7b')
parser.add_argument("--lora_folder", default="")
parser.add_argument("--lora_folder2", default="")
parser.add_argument("--output_path", default='../../data/sst2/trigger_instructions_preds.json')
parser.add_argument("--cache_dir", default="../cache")

args = parser.parse_args()
print(args)

if os.path.exists(args.output_path):
    print("Output file already exists; it will be overwritten.")
output_folder = os.path.dirname(args.output_path)
os.makedirs(output_folder, exist_ok=True)
from datasets import load_dataset

# Build the evaluation set from the first 1000 AG News test examples.
dataset = load_dataset("ag_news")
index = 0
input_data_lst = []
for example in dataset["test"]:
    if index < 1000:
        instance = {}
        instance["instruction"] = "Categorize the news article given in the input into one of the 4 categories:\n\nWorld\nSports\nBusiness\nSci/Tech\n"
        instance["input"] = example["text"]
        instance["label"] = example["label"]
        input_data_lst += [instance]
        index += 1

# instruction_lst = instruction_lst[:10]
tokenizer = AutoTokenizer.from_pretrained(args.model_folder, cache_dir=args.cache_dir, use_fast=True, token=access_token)
tokenizer.pad_token_id = 0
model = AutoModelForCausalLM.from_pretrained(args.model_folder, cache_dir=args.cache_dir, load_in_8bit=False, torch_dtype=torch.float16, device_map="auto", token=access_token)

if args.lora_folder != "":
    print("Recover LoRA weights..")
    model = PeftModel.from_pretrained(
        model,
        args.lora_folder,
        torch_dtype=torch.float16,
    )
    model = model.merge_and_unload()

if args.lora_folder2 != "":
    print("Recover LoRA weights..")
    model = PeftModel.from_pretrained(
        model,
        args.lora_folder2,
        torch_dtype=torch.float16,
    )
    model = model.merge_and_unload()
print(model)

model.eval()
def query(data):
    # Build an Alpaca-style prompt from the instruction and the news article.
    instruction = data["instruction"]
    input = data["input"]
    prompt = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    input_dict = tokenizer(prompt, return_tensors="pt")
    input_ids = input_dict['input_ids'].cuda()
    with torch.no_grad():
        generation_output = model.generate(
            inputs=input_ids,
            top_p=1,
            temperature=1.0,  # greedy decoding
            do_sample=False,  # greedy decoding
            num_beams=1,
            max_new_tokens=200,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    s = generation_output[0]
    output = tokenizer.decode(s, skip_special_tokens=True)
    # Keep only the text generated after the "### Response:" marker.
    res = output.split("### Response:")[1].strip()
    return res
pred_lst = []
for data in tqdm(input_data_lst):
    pred = query(data)
    pred_lst.append(pred)

output_lst = []
correct = 0
total = 0
for input_data, pred in zip(input_data_lst, pred_lst):
    input_data['output'] = pred
    # Map the integer label to the category strings used for matching.
    if input_data["label"] == 0:
        label1 = "World"
        label2 = "world"
    elif input_data["label"] == 1:
        label1 = "Sports"
        label2 = "sports"
    elif input_data["label"] == 2:
        label1 = "Business"
        label2 = "business"
    else:
        label1 = "Sci/Tech"
        label2 = "sci"

    if label1 == pred or label2 == pred:
        correct += 1
        input_data["correct"] = "true"
    else:
        # print(label + " " + pred)
        input_data["correct"] = "false"
    total += 1
    output_lst.append(input_data)
print("{:.2f}".format(correct / total * 100))
output_lst.append("score={}".format(correct / total * 100))
with open(args.output_path, 'w') as f:
    json.dump(output_lst, f, indent=4)
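
Note that the scoring above counts a prediction as correct only when the generated text equals the category name exactly. If you want a looser check, a case-insensitive containment test is one possible variant (a suggestion, not the scoring rule used by the script):

```
def matches_label(pred: str, label: str) -> bool:
    """Looser matching variant: count a prediction as correct when the category
    name appears anywhere in the generated text (case-insensitive).
    This is a suggested alternative, not the repository's scoring rule."""
    return label.lower() in pred.lower()


# Example: an answer such as "The article belongs to Sports." is counted as
# correct for the label "Sports" under this variant.
assert matches_label("The article belongs to Sports.", "Sports")
assert not matches_label("This looks like Business news.", "Sports")
```
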

@@ -0,0 +1,24 @@ (gsm8k/build_dataset.py)

import random
import json
import os
import argparse

random.seed(0)

parser = argparse.ArgumentParser()
args = parser.parse_args()
ANSWER_PROMPT = "The final answer is: "
QUESTION_PROMPT = ""

from datasets import load_dataset

# Build an instruction dataset from the GSM8K training split; the "####"
# final-answer marker is rewritten into ANSWER_PROMPT.
dataset = load_dataset("gsm8k", 'main')
output_json = '../data/gsm8k.json'
output_data_lst = []
for data in dataset["train"]:
    print(data)
    item = {}
    item["instruction"] = f"{data['question']}{QUESTION_PROMPT}"
    item["output"] = f"{data['answer']}".replace("####", ANSWER_PROMPT)
    output_data_lst += [item]
with open(output_json, 'w', encoding='utf-8') as f:
    json.dump(output_data_lst, f, indent=4)
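
As a small illustration of the `replace` call above (the answer string is a hypothetical stand-in, not a real GSM8K record):

```
ANSWER_PROMPT = "The final answer is: "

# Hypothetical GSM8K-style answer string; GSM8K marks the final numeric answer
# with a trailing "#### <number>".
raw_answer = "She sold 48 + 24 = 72 clips.\n#### 72"

rewritten = raw_answer.replace("####", ANSWER_PROMPT)
print(rewritten)  # the "####" marker is now the phrase "The final answer is: "
```
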