main.py
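
# Evaluate the SmoothLLM defense against a jailbreaking attack and save a
# one-row summary of the results. Example usage with the default flags
# (the logfile path assumes the GCG behavior logs bundled with the repo):
#
#   python main.py \
#       --target_model vicuna \
#       --attack GCG \
#       --attack_logfile data/GCG/vicuna_behaviors.json \
#       --smoothllm_num_copies 10 \
#       --smoothllm_pert_pct 10 \
#       --smoothllm_pert_type RandomSwapPerturbation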

import os
import argparse

import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import lib.perturbations as perturbations
import lib.defenses as defenses
import lib.attacks as attacks
import lib.language_models as language_models
import lib.model_configs as model_configs


def main(args):

    # Create output directories
    os.makedirs(args.results_dir, exist_ok=True)

    # Instantiate the targeted LLM
    config = model_configs.MODELS[args.target_model]
    target_model = language_models.LLM(
        model_path=config['model_path'],
        tokenizer_path=config['tokenizer_path'],
        conv_template_name=config['conversation_template'],
        device='cuda:0'
    )
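
    # Note: per the SmoothLLM algorithm, the defense below duplicates each
    # prompt num_copies times, randomly perturbs pert_pct% of the characters
    # in every copy, queries the target model on each copy, and aggregates
    # the outputs by majority vote to decide whether the attack succeeded.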
    # Create SmoothLLM instance
    defense = defenses.SmoothLLM(
        target_model=target_model,
        pert_type=args.smoothllm_pert_type,
        pert_pct=args.smoothllm_pert_pct,
        num_copies=args.smoothllm_num_copies
    )

    # Create attack instance, used to create prompts;
    # vars(attacks)[args.attack] resolves the attack class ('GCG' or 'PAIR')
    # by name from lib/attacks.py
    attack = vars(attacks)[args.attack](
        logfile=args.attack_logfile,
        target_model=target_model
    )

    # Run the defense on every attack prompt and record whether the
    # resulting response constitutes a jailbreak
    jailbroken_results = []
    for i, prompt in tqdm(enumerate(attack.prompts)):
        output = defense(prompt)
        jb = defense.is_jailbroken(output)
        jailbroken_results.append(jb)

    num_jailbreaks = sum(jailbroken_results)
    print(f'Jailbroken on {num_jailbreaks} of {len(jailbroken_results)} prompts')

    # Save results to a pandas DataFrame
    summary_df = pd.DataFrame.from_dict({
        'Number of smoothing copies': [args.smoothllm_num_copies],
        'Perturbation type': [args.smoothllm_pert_type],
        'Perturbation percentage': [args.smoothllm_pert_pct],
        'JB percentage': [np.mean(jailbroken_results) * 100],
        'Trial index': [args.trial]
    })
    summary_df.to_pickle(os.path.join(
        args.results_dir, 'summary.pd'
    ))
    print(summary_df)


if __name__ == '__main__':
    torch.cuda.empty_cache()

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--results_dir',
        type=str,
        default='./results'
    )
    parser.add_argument(
        '--trial',
        type=int,
        default=0
    )

    # Targeted LLM
    parser.add_argument(
        '--target_model',
        type=str,
        default='vicuna',
        choices=['vicuna', 'llama2']
    )

    # Attacking LLM
    parser.add_argument(
        '--attack',
        type=str,
        default='GCG',
        choices=['GCG', 'PAIR']
    )
    parser.add_argument(
        '--attack_logfile',
        type=str,
        default='data/GCG/vicuna_behaviors.json'
    )

    # SmoothLLM
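    # (Per the SmoothLLM paper: RandomSwapPerturbation replaces randomly
    # chosen characters, RandomPatchPerturbation replaces one contiguous
    # patch of characters, and RandomInsertPerturbation inserts new random
    # characters; smoothllm_pert_pct sets the percentage of the prompt
    # that is perturbed.)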
    parser.add_argument(
        '--smoothllm_num_copies',
        type=int,
        default=10
    )
    parser.add_argument(
        '--smoothllm_pert_pct',
        type=int,
        default=10
    )
    parser.add_argument(
        '--smoothllm_pert_type',
        type=str,
        default='RandomSwapPerturbation',
        choices=[
            'RandomSwapPerturbation',
            'RandomPatchPerturbation',
            'RandomInsertPerturbation'
        ]
    )

    args = parser.parse_args()
    main(args)