first commit
huangtiansheng committed Mar 19, 2024
0 parents commit 2e9c24c
Showing 23 changed files with 3,114 additions and 0 deletions.
48 changes: 48 additions & 0 deletions .gitignore
@@ -0,0 +1,48 @@
# Ignore Mac system files
.DS_Store

# Ignore node_modules folder
node_modules

__pycache__

pic_draw

ckpt

cache

# Ignore files related to API keys
*.env

# Ignore SASS cache files
*.sass-cache

# Ignore model checkpoint files
*.pt

# Ignore JSON data files
*.json

# Ignore log and output files
*.log

*.out

# Ignore IDE settings and compiled Python files
.idea

*.pyc

data_copy

script_copy

ckpt_copy

fig

*draw_pic.py

*prune_test.py

*visual_test.py
86 changes: 86 additions & 0 deletions README.md
@@ -0,0 +1,86 @@
<!-- markdownlint-disable first-line-h1 -->
<!-- markdownlint-disable html -->

<h1 align="center">Lisa: Lazy Safety Alignment for Large Language Models against Harmful Fine-tuning</h1>



Lisa is a safety alignment method against the threat of harmful fine-tuning. We consider a two-stage fine-tuning scheme: i) the alignment stage, in which we align the model with a human-preference dataset (the alignment dataset), and ii) the fine-tuning stage, in which we fine-tune the model with a user fine-tuning dataset (which is mixed with harmful instances). Lisa is applied in the fine-tuning stage, where Bi-state optimization with a proximal term is used to mitigate the risk posed by the mixed-in harmful data.


## Main code logic
We implement a customized trainer on top of the original HuggingFace Trainer. To achieve Bi-state optimization, we append one line of code in the function `training_step()` of `trainer.py`.

```
inputs = self.check_mode(inputs)  # appended code: switch dataset/model state according to the step number
loss = step()                     # gradient backward with the selected data and model state
```
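
For reference, here is a minimal sketch of what `check_mode()` could look like. The switching interval `switch_every`, the iterator name `alignment_dataloader_iter`, and which state runs first are illustrative assumptions, not the repository's exact code:

```
def check_mode(self, inputs):
    # Illustrative sketch only. Every `switch_every` optimizer steps, flip between
    # the fine-tuning state and the alignment state, and draw the batch accordingly.
    if (self.state.global_step // self.switch_every) % 2 == 0:
        self.status = "finetune"      # keep the user fine-tuning batch as-is
    else:
        self.status = "alignment"
        inputs = next(self.alignment_dataloader_iter)  # assumed iterator over the alignment dataset
    return inputs
```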

To introduce a proximal term that pulls the two states towards consensus, we add the following regularization to the loss in the function `step()`.

```
if self.status =="alignment":
for name, param in model.named_parameters():
if param.requires_grad and self.args.rho>0:
loss += self.args.rho/2* torch.norm( param- self.alignment_weights[name])**2
else:
for name, param in model.named_parameters():
if param.requires_grad and self.args.rho>0:
loss += self.args.rho/2* torch.norm( param- self.finetune_weights[name])**2
```
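
Here `self.alignment_weights` and `self.finetune_weights` are snapshots of the model taken in the corresponding state. A minimal sketch of how such a snapshot could be captured follows; when exactly it is refreshed is an assumption here:

```
def snapshot_weights(self, model):
    # Detach and copy the current trainable parameters so that the proximal
    # term can pull later updates back toward this state's weights.
    return {
        name: param.detach().clone()
        for name, param in model.named_parameters()
        if param.requires_grad
    }

# e.g. refreshed whenever the optimizer switches state:
# self.alignment_weights = self.snapshot_weights(model)
```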




## Package requirement
The package requirements are listed in `lisa.yml` and `lisa_pip.txt`. Run the following commands to install the packages with Anaconda and pip.
```
conda env create -f lisa.yml
pip install -r lisa_pip.txt
```

## Data preparation
For the fine-tuning tasks, we first need to run the following scripts to prepare the supervised fine-tuning data.
```
cd sst2
python build_dataset.py
cd ../gsm8k
python build_dataset.py
cd ../agnews
python build_dataset.py
cd ..
```

## Huggingface Llama2 access
Llama2-7B is a gated repo, which requires a formal request to get access to the model. Check out https://huggingface.co/meta-llama/Llama-2-7b-hf.
After Meta approves your request, you should be able to access the model, but you first need to put your HuggingFace access token in the file `huggingface_token.txt`.
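
The scripts read the token from the first line of that file and pass it to `from_pretrained`, following the pattern used in `agnews/pred_eval.py` (the model name below is only an illustration):

```
import torch
from transformers import AutoModelForCausalLM

access_token = next(open('huggingface_token.txt')).strip()
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, token=access_token
)
```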



## Example command to run

We provide scripts for reproducing all the experiments in the paper. We recommend using Slurm, as the logging files will then be automatically organized into the script directory (if you don't use Slurm, just replace `sbatch` with `bash` in our examples).

We first run SFT to produce the aligned model.
```
cd script/alignment
sbatch SFT.sh
```
Then we fine-tune the model on a total of 5,000 samples from the SST2 dataset, 10% of which are harmful data.
```
cd ../finetune
sbatch lisa_poison_ratio.sh 0.1
```


For comparison, we also fine-tune the model with plain SFT in the same data setting.

```
sbatch sft_poison_ratio.sh 0.1
cd ../..
```





31 changes: 31 additions & 0 deletions agnews/build_dataset.py
@@ -0,0 +1,31 @@
import random
import json
import os
import argparse

random.seed(0)

parser = argparse.ArgumentParser()
args = parser.parse_args()


from datasets import load_dataset
dataset = load_dataset("ag_news")
output_json = f'../data/agnews.json'
output_data_lst = []
for data in dataset["train"]:
print(data)
item = {}
item["instruction"] = "Categorize the news article given in the input into one of the 4 categories:\n\nWorld\nSports\nBusiness\nSci/Tech\n"
item["input"] = data["text"]
if data["label"] == 0:
item["output"] = "World"
elif data["label"] == 1:
item["output"] = "Sports"
elif data["label"] == 2:
item["output"] = "Business"
else:
item["output"] = "Sci/Tech"
output_data_lst += [item]
with open(output_json, 'w', encoding='utf-8') as f:
json.dump(output_data_lst, f, indent=4)
125 changes: 125 additions & 0 deletions agnews/pred_eval.py
@@ -0,0 +1,125 @@
import os
import json
import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
from peft import PeftModel

access_token = next(open('../huggingface_token.txt')).strip()

parser = argparse.ArgumentParser()
parser.add_argument("--model_folder", default='wxjiao/alpaca-7b')
parser.add_argument("--lora_folder", default="")
parser.add_argument("--lora_folder2", default="")
parser.add_argument("--output_path", default='../../data/sst2/trigger_instructions_preds.json')
parser.add_argument("--cache_dir", default= "../cache")

args = parser.parse_args()
print(args)

if os.path.exists(args.output_path):
    print("Output file already exists; it will be overwritten.")
output_folder = os.path.dirname(args.output_path)
os.makedirs(output_folder, exist_ok=True)

from datasets import load_dataset
dataset =load_dataset("ag_news")
index=0
input_data_lst = []
for example in dataset["test"]:
if index<1000 :
instance = {}
instance["instruction"] = "Categorize the news article given in the input into one of the 4 categories:\n\nWorld\nSports\nBusiness\nSci/Tech\n"
instance["input"] = example["text"]
instance["label"] = example["label"]
input_data_lst += [instance]
index+=1

# instruction_lst = instruction_lst[:10]
tokenizer = AutoTokenizer.from_pretrained(args.model_folder, cache_dir=args.cache_dir, use_fast=True, token=access_token)
tokenizer.pad_token_id = 0
model = AutoModelForCausalLM.from_pretrained(args.model_folder, cache_dir=args.cache_dir, load_in_8bit=False, torch_dtype=torch.float16, device_map="auto", token=access_token)

if args.lora_folder!="":
print("Recover LoRA weights..")
model = PeftModel.from_pretrained(
model,
args.lora_folder,
torch_dtype=torch.float16,
)
model = model.merge_and_unload()

if args.lora_folder2!="":
print("Recover LoRA weights..")
model = PeftModel.from_pretrained(
model,
args.lora_folder2,
torch_dtype=torch.float16,
)
model = model.merge_and_unload()
print(model)


model.eval()


def query(data):
    instruction = data["instruction"]
    input = data["input"]
    prompt = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    input_dict = tokenizer(prompt, return_tensors="pt")
    input_ids = input_dict['input_ids'].cuda()
    with torch.no_grad():
        generation_output = model.generate(
            inputs=input_ids,
            top_p=1,
            temperature=1.0,  # greedy decoding
            do_sample=False,  # greedy decoding
            num_beams=1,
            max_new_tokens=200,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    s = generation_output[0]
    output = tokenizer.decode(s, skip_special_tokens=True)
    res = output.split("### Response:")[1].strip()
    return res


pred_lst = []
for data in tqdm(input_data_lst):
    pred = query(data)
    pred_lst.append(pred)

output_lst = []
correct = 0
total = 0
for input_data, pred in zip(input_data_lst, pred_lst):
    input_data['output'] = pred
    if input_data["label"] == 0:
        label1 = "World"
        label2 = "world"
    elif input_data["label"] == 1:
        label1 = "Sports"
        label2 = "sports"
    elif input_data["label"] == 2:
        label1 = "Business"
        label2 = "business"
    else:
        label1 = "Sci/Tech"
        label2 = "sci"

    if label1 == pred or label2 == pred:
        correct += 1
        input_data["correct"] = "true"
    else:
        # print(label1 + " " + pred)
        input_data["correct"] = "false"
    total += 1
    output_lst.append(input_data)

print("{:.2f}".format(correct / total * 100))
# Append the accuracy (as a percentage) to the saved predictions.
output_lst.append("score={}".format(correct / total * 100))
with open(args.output_path, 'w') as f:
    json.dump(output_lst, f, indent=4)
24 changes: 24 additions & 0 deletions gsm8k/build_dataset.py
@@ -0,0 +1,24 @@
import random
import json
import os
import argparse

random.seed(0)

parser = argparse.ArgumentParser()
args = parser.parse_args()
ANSWER_PROMPT = "The final answer is: "
QUESTION_PROMPT = ""

from datasets import load_dataset
dataset = load_dataset("gsm8k", 'main')
output_json = f'../data/gsm8k.json'
output_data_lst = []
for data in dataset["train"]:
print(data)
item = {}
item["instruction"] = f"{data['question']}{QUESTION_PROMPT}"
item["output"] = f"{data['answer']}".replace("####", ANSWER_PROMPT)
output_data_lst += [item]
with open(output_json, 'w', encoding='utf-8') as f:
json.dump(output_data_lst, f, indent=4)