Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
huangtiansheng committed Aug 16, 2024
1 parent 545e3cd commit f8bd542
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 45 deletions.
90 changes: 90 additions & 0 deletions poison/evaluation/evaluate_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2023 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase
import sys
sys.path.append('../..')
from constants import DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_PAD_TOKEN, DEFAULT_UNK_TOKEN


# Reference: https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py
def resize_tokenizer_embedding(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
) -> None:
"""Resize tokenizer and embedding."""

special_tokens_dict = {}
if tokenizer.pad_token is None:
special_tokens_dict['pad_token'] = DEFAULT_PAD_TOKEN
if tokenizer.eos_token is None:
special_tokens_dict['eos_token'] = DEFAULT_EOS_TOKEN
if tokenizer.bos_token is None:
special_tokens_dict['bos_token'] = DEFAULT_BOS_TOKEN
if tokenizer.unk_token is None:
special_tokens_dict['unk_token'] = DEFAULT_UNK_TOKEN

num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))

model.config.bos_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

if num_new_tokens > 0:
if model.get_input_embeddings() is not None:
input_embeddings = model.get_input_embeddings().weight.data
input_embeddings_mean = input_embeddings[:-num_new_tokens].mean(
dim=0,
keepdim=True,
)
input_embeddings[-num_new_tokens:] = input_embeddings_mean

if model.get_output_embeddings() is not None:
output_embeddings = model.get_output_embeddings().weight.data
output_embeddings_mean = output_embeddings[:-num_new_tokens].mean(
dim=0,
keepdim=True,
)
output_embeddings[-num_new_tokens:] = output_embeddings_mean


def calculate_binary_classification_metrics(
labels: torch.Tensor,
predictions: torch.Tensor,
epsilon: float = 1e-8,
) -> dict[str, float]:
"""Calculate binary classification metrics."""
assert (
labels.shape == predictions.shape
), 'The shapes of labels and predictions should be the same.'

tp = ((labels == 1) & (predictions == 1)).sum().item() # pylint: disable=invalid-name
fp = ((labels == 0) & (predictions == 1)).sum().item() # pylint: disable=invalid-name
tn = ((labels == 0) & (predictions == 0)).sum().item() # pylint: disable=invalid-name
fn = ((labels == 1) & (predictions == 0)).sum().item() # pylint: disable=invalid-name
accuracy = (tp + tn) / (tp + fp + tn + fn)
precision = tp / (tp + fp + epsilon)
recall = tp / (tp + fn + epsilon)
f1 = 2 * precision * recall / (precision + recall + epsilon) # pylint: disable=invalid-name
return {
'accuracy': accuracy,
'precision': precision,
'recall': recall,
'f1': f1,
}
5 changes: 3 additions & 2 deletions poison/evaluation/moderation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@
)
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from transformers.trainer_utils import EvalPrediction

import sys
sys.path.append('../..')
from constants import PROMPT_INPUT
from utils import calculate_binary_classification_metrics, resize_tokenizer_embedding
from evaluate_utils import calculate_binary_classification_metrics, resize_tokenizer_embedding


__all__ = ['Moderation']
Expand Down
6 changes: 3 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import transformers
from transformers import TrainerCallback
from torch.utils.data import Dataset
from trainer import BaseTrainer,FITrainer,RandomVaccineTrainer,ADMMTrainer
from trainer import VaccineTrainer,FITrainer,RandomVaccineTrainer,LisaTrainer
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, PeftModel
import wandb
wandb.init(mode="disabled")
Expand Down Expand Up @@ -419,7 +419,7 @@ def train():
if training_args.optimizer=="vaccine":
print("init vaccine")
import torch.optim as optim
trainer = BaseTrainer(model=model, tokenizer=tokenizer, args=training_args,**data_module)
trainer = VaccineTrainer(model=model, tokenizer=tokenizer, args=training_args,**data_module)
trainer.density = training_args.density
elif "EWC" in training_args.optimizer:
import torch.optim as optim
Expand All @@ -428,7 +428,7 @@ def train():
elif training_args.optimizer == "random_vaccine":
trainer = RandomVaccineTrainer(model=model, tokenizer=tokenizer, args=training_args,**data_module)
elif training_args.optimizer == "lisa":
trainer = ADMMTrainer(model=model, tokenizer=tokenizer, args=training_args,**data_module)
trainer = LisaTrainer(model=model, tokenizer=tokenizer, args=training_args,**data_module)
alignment_dataset = SupervisedDataset(tokenizer=tokenizer, data_path="BeaverTails_safe",guide_data_num=data_args.guide_data_num)
trainer.init(alignment_dataset)
elif training_args.optimizer == "vlguard":
Expand Down
71 changes: 31 additions & 40 deletions trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@



class ADMMTrainer(Trainer):
class LisaTrainer(Trainer):

def get_alignment_dataloader(self,alignment_dataset) -> DataLoader:
"""
Expand Down Expand Up @@ -179,15 +179,12 @@ def step():
loss = loss.mean() # mean() to average on multi-gpu parallel training
if self.status =="alignment":
# print("alignment_loss_prev: {}".format(loss.item()))
# don't do proximal in the inital 10% of steps. It will downgrade benign accuracy'
if self.steps>0.1* len(self.get_train_dataloader()) * self.args.num_train_epochs:
for name, param in model.named_parameters():
if param.requires_grad and self.args.rho>0:
# loss +=torch.sum(self.gamma[name] * param)+ self.args.rho/2* torch.norm( param- self.finetune_weights[name])**2
loss += self.args.rho/2* torch.norm( param- self.finetune_weights[name])**2
# print("alignment_loss: {}".format(loss.item()))
else:
# print("finetune_loss_prev: {}".format(loss.item()))

if self.steps>0.1* len(self.get_train_dataloader()) * self.args.num_train_epochs:
for name, param in model.named_parameters():
# we observe that for Gsm8k, proximal term will hurt convergence. Don't do proximal for the first few rounds.
Expand Down Expand Up @@ -229,7 +226,7 @@ def get_leaf_modules_with_grad(module):
return module_list


class BaseTrainer(Trainer):
class VaccineTrainer(Trainer):
def training_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
) -> torch.Tensor:
Expand Down Expand Up @@ -257,9 +254,9 @@ def step():

# if isinstance(self.optimizer,ESAM ):
# print("calling sam")
self.sam_state = {}
self.sam_state ["hooks"] = []
self.sam_state ["gradient"] = {}
self.vaccine_state = {}
self.vaccine_state ["hooks"] = []
self.vaccine_state ["gradient"] = {}
self.pre_first_step(model)
step()
self.after_first_step(model)
Expand All @@ -279,7 +276,7 @@ def step():
def pre_first_step(self, model ):
def track_gradient_hook(module, grad_input, grad_output):
# Store the gradients for the current layer
self.sam_state["gradient"][module] = grad_output[0].detach().clone()/self.args.gradient_accumulation_steps
self.vaccine_state["gradient"][module] = grad_output[0].detach().clone()/self.args.gradient_accumulation_steps
# print(grad_output[0])

def apply_backward_hooks_recursive(module, hook_fn, hooks):
Expand All @@ -289,16 +286,16 @@ def apply_backward_hooks_recursive(module, hook_fn, hooks):
# Call the function with the initial empty hooks list
leaf_modules_with_grad = get_leaf_modules_with_grad(model)
for layer in leaf_modules_with_grad:
self.sam_state["gradient"][layer] = 0
apply_backward_hooks_recursive(layer, track_gradient_hook, self.sam_state["hooks"])
self.vaccine_state["gradient"][layer] = 0
apply_backward_hooks_recursive(layer, track_gradient_hook, self.vaccine_state["hooks"])



@torch.no_grad()
def pre_second_step(self, model):
def purturbation_hook(module, input, output):
# Modify the output, for example, by adding a perturbatio
perturbation = self.sam_state["gradient"][module]
perturbation = self.vaccine_state["gradient"][module]
# print(perturbation[0,1,:])
# # print(output.shape)
# print(output[0,1,:])
Expand All @@ -317,33 +314,33 @@ def apply_purturbation_hooks_recursive(module, hook_fn, hooks):
for layer in leaf_modules_with_grad:
# print(layer._get_name())
# Apply hooks to all layers, including nested Sequential blocks
apply_purturbation_hooks_recursive(layer, purturbation_hook, self.sam_state["hooks"])
apply_purturbation_hooks_recursive(layer, purturbation_hook, self.vaccine_state["hooks"])

@torch.no_grad()
def after_first_step(self, model):
for hook in self.sam_state["hooks"]:
for hook in self.vaccine_state["hooks"]:
hook.remove()
self.sam_state["hooks"] = []
self.vaccine_state["hooks"] = []

# print(self.sam_state["gradient"].items())
grad_norm = self._grad_norm(self.sam_state["gradient"])
# print(self.vaccine_state["gradient"].items())
grad_norm = self._grad_norm(self.vaccine_state["gradient"])
# logging.info(grad_norm)
# logging.info("norm{}".format(grad_norm))
for module in self.sam_state["gradient"]:
# grad_norm = self._grad_norm(self.sam_state["gradient"][module])
grad = self.sam_state["gradient"][module]
for module in self.vaccine_state["gradient"]:
# grad_norm = self._grad_norm(self.vaccine_state["gradient"][module])
grad = self.vaccine_state["gradient"][module]
scale = self. args. rho / (grad_norm +1e-7)
e_r = (grad)* scale
self.sam_state["gradient"][module] = e_r.detach().clone()
self.vaccine_state["gradient"][module] = e_r.detach().clone()

@torch.no_grad()
def after_second_step(self, model):
# disable hook here
# for module in self.sam_state["e_r"]:
# module.weight.data -= self.sam_state["e_r"][module]
for hook in self.sam_state["hooks"]:
# for module in self.vaccine_state["e_r"]:
# module.weight.data -= self.vaccine_state["e_r"][module]
for hook in self.vaccine_state["hooks"]:
hook.remove()
self.sam_state["hooks"] = []
self.vaccine_state["hooks"] = []
# torch.nn.utils.clip_grad_norm_(model.parameters(), 10)


Expand Down Expand Up @@ -390,9 +387,9 @@ def step():
# print("gere2")
return loss

self.sam_state = {}
self.sam_state ["hooks"] = []
self.sam_state ["gradient"] = {}
self.vaccine_state = {}
self.vaccine_state ["hooks"] = []
self.vaccine_state ["gradient"] = {}
self.pre_second_step(model)
loss = step()
self.after_second_step(model)
Expand Down Expand Up @@ -432,17 +429,17 @@ def apply_purturbation_hooks_recursive(module, hook_fn, hooks):
for layer in leaf_modules_with_grad:
# print(layer._get_name())
# Apply hooks to all layers, including nested Sequential blocks
apply_purturbation_hooks_recursive(layer, purturbation_hook, self.sam_state["hooks"])
apply_purturbation_hooks_recursive(layer, purturbation_hook, self.vaccine_state["hooks"])


@torch.no_grad()
def after_second_step(self, model):
# disable hook here
# for module in self.sam_state["e_r"]:
# module.weight.data -= self.sam_state["e_r"][module]
for hook in self.sam_state["hooks"]:
# for module in self.vaccine_state["e_r"]:
# module.weight.data -= self.vaccine_state["e_r"][module]
for hook in self.vaccine_state["hooks"]:
hook.remove()
self.sam_state["hooks"] = []
self.vaccine_state["hooks"] = []
# torch.nn.utils.clip_grad_norm_(model.parameters(), 10)


Expand Down Expand Up @@ -521,11 +518,5 @@ def step():


loss = step()
# print( sum([torch.norm(self.sam_state ["gradient"][module]) for module in self.sam_state ["gradient"] ]))
# leaf_modules_with_grad = get_leaf_modules_with_grad(model)
# for module in leaf_modules_with_grad:
# # print(module.q_proj.lora_A["default"])
# module.weight.grad*= (1-self.masks[index])
# index+=1
self.round+=1
return loss.detach() / self.args.gradient_accumulation_steps

0 comments on commit f8bd542

Please sign in to comment.