From f8bd542010c96cd88be997c619275e091fcc817f Mon Sep 17 00:00:00 2001 From: Tiansheng Huang <873966702@qq.com> Date: Sat, 17 Aug 2024 02:44:46 +0800 Subject: [PATCH] cleanup --- poison/evaluation/evaluate_utils.py | 90 +++++++++++++++++++++++++++++ poison/evaluation/moderation.py | 5 +- train.py | 6 +- trainer.py | 71 ++++++++++------------- 4 files changed, 127 insertions(+), 45 deletions(-) create mode 100644 poison/evaluation/evaluate_utils.py diff --git a/poison/evaluation/evaluate_utils.py b/poison/evaluation/evaluate_utils.py new file mode 100644 index 0000000..b69100a --- /dev/null +++ b/poison/evaluation/evaluate_utils.py @@ -0,0 +1,90 @@ +# Copyright 2023 PKU-Alignment Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import annotations + +import torch +from transformers import PreTrainedModel, PreTrainedTokenizerBase +import sys +sys.path.append('../..') +from constants import DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_PAD_TOKEN, DEFAULT_UNK_TOKEN + + +# Reference: https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py +def resize_tokenizer_embedding( + model: PreTrainedModel, + tokenizer: PreTrainedTokenizerBase, +) -> None: + """Resize tokenizer and embedding.""" + + special_tokens_dict = {} + if tokenizer.pad_token is None: + special_tokens_dict['pad_token'] = DEFAULT_PAD_TOKEN + if tokenizer.eos_token is None: + special_tokens_dict['eos_token'] = DEFAULT_EOS_TOKEN + if tokenizer.bos_token is None: + special_tokens_dict['bos_token'] = DEFAULT_BOS_TOKEN + if tokenizer.unk_token is None: + special_tokens_dict['unk_token'] = DEFAULT_UNK_TOKEN + + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + model.config.bos_token_id = tokenizer.bos_token_id + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + + if num_new_tokens > 0: + if model.get_input_embeddings() is not None: + input_embeddings = model.get_input_embeddings().weight.data + input_embeddings_mean = input_embeddings[:-num_new_tokens].mean( + dim=0, + keepdim=True, + ) + input_embeddings[-num_new_tokens:] = input_embeddings_mean + + if model.get_output_embeddings() is not None: + output_embeddings = model.get_output_embeddings().weight.data + output_embeddings_mean = output_embeddings[:-num_new_tokens].mean( + dim=0, + keepdim=True, + ) + output_embeddings[-num_new_tokens:] = output_embeddings_mean + + +def calculate_binary_classification_metrics( + labels: torch.Tensor, + predictions: torch.Tensor, + epsilon: float = 1e-8, +) -> dict[str, float]: + """Calculate binary classification metrics.""" + assert ( + labels.shape == predictions.shape + ), 'The shapes of labels and predictions should be the same.' + + tp = ((labels == 1) & (predictions == 1)).sum().item() # pylint: disable=invalid-name + fp = ((labels == 0) & (predictions == 1)).sum().item() # pylint: disable=invalid-name + tn = ((labels == 0) & (predictions == 0)).sum().item() # pylint: disable=invalid-name + fn = ((labels == 1) & (predictions == 0)).sum().item() # pylint: disable=invalid-name + accuracy = (tp + tn) / (tp + fp + tn + fn) + precision = tp / (tp + fp + epsilon) + recall = tp / (tp + fn + epsilon) + f1 = 2 * precision * recall / (precision + recall + epsilon) # pylint: disable=invalid-name + return { + 'accuracy': accuracy, + 'precision': precision, + 'recall': recall, + 'f1': f1, + } \ No newline at end of file diff --git a/poison/evaluation/moderation.py b/poison/evaluation/moderation.py index e67dde1..296e6cc 100644 --- a/poison/evaluation/moderation.py +++ b/poison/evaluation/moderation.py @@ -34,9 +34,10 @@ ) from transformers.modeling_outputs import SequenceClassifierOutputWithPast from transformers.trainer_utils import EvalPrediction - +import sys +sys.path.append('../..') from constants import PROMPT_INPUT -from utils import calculate_binary_classification_metrics, resize_tokenizer_embedding +from evaluate_utils import calculate_binary_classification_metrics, resize_tokenizer_embedding __all__ = ['Moderation'] diff --git a/train.py b/train.py index 4c10c0c..970b9fa 100644 --- a/train.py +++ b/train.py @@ -23,7 +23,7 @@ import transformers from transformers import TrainerCallback from torch.utils.data import Dataset -from trainer import BaseTrainer,FITrainer,RandomVaccineTrainer,ADMMTrainer +from trainer import VaccineTrainer,FITrainer,RandomVaccineTrainer,LisaTrainer from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, PeftModel import wandb wandb.init(mode="disabled") @@ -419,7 +419,7 @@ def train(): if training_args.optimizer=="vaccine": print("init vaccine") import torch.optim as optim - trainer = BaseTrainer(model=model, tokenizer=tokenizer, args=training_args,**data_module) + trainer = VaccineTrainer(model=model, tokenizer=tokenizer, args=training_args,**data_module) trainer.density = training_args.density elif "EWC" in training_args.optimizer: import torch.optim as optim @@ -428,7 +428,7 @@ def train(): elif training_args.optimizer == "random_vaccine": trainer = RandomVaccineTrainer(model=model, tokenizer=tokenizer, args=training_args,**data_module) elif training_args.optimizer == "lisa": - trainer = ADMMTrainer(model=model, tokenizer=tokenizer, args=training_args,**data_module) + trainer = LisaTrainer(model=model, tokenizer=tokenizer, args=training_args,**data_module) alignment_dataset = SupervisedDataset(tokenizer=tokenizer, data_path="BeaverTails_safe",guide_data_num=data_args.guide_data_num) trainer.init(alignment_dataset) elif training_args.optimizer == "vlguard": diff --git a/trainer.py b/trainer.py index 46c9a34..bec4b49 100644 --- a/trainer.py +++ b/trainer.py @@ -37,7 +37,7 @@ -class ADMMTrainer(Trainer): +class LisaTrainer(Trainer): def get_alignment_dataloader(self,alignment_dataset) -> DataLoader: """ @@ -179,15 +179,12 @@ def step(): loss = loss.mean() # mean() to average on multi-gpu parallel training if self.status =="alignment": # print("alignment_loss_prev: {}".format(loss.item())) + # don't do proximal in the inital 10% of steps. It will downgrade benign accuracy' if self.steps>0.1* len(self.get_train_dataloader()) * self.args.num_train_epochs: for name, param in model.named_parameters(): if param.requires_grad and self.args.rho>0: - # loss +=torch.sum(self.gamma[name] * param)+ self.args.rho/2* torch.norm( param- self.finetune_weights[name])**2 loss += self.args.rho/2* torch.norm( param- self.finetune_weights[name])**2 - # print("alignment_loss: {}".format(loss.item())) else: - # print("finetune_loss_prev: {}".format(loss.item())) - if self.steps>0.1* len(self.get_train_dataloader()) * self.args.num_train_epochs: for name, param in model.named_parameters(): # we observe that for Gsm8k, proximal term will hurt convergence. Don't do proximal for the first few rounds. @@ -229,7 +226,7 @@ def get_leaf_modules_with_grad(module): return module_list -class BaseTrainer(Trainer): +class VaccineTrainer(Trainer): def training_step( self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]] ) -> torch.Tensor: @@ -257,9 +254,9 @@ def step(): # if isinstance(self.optimizer,ESAM ): # print("calling sam") - self.sam_state = {} - self.sam_state ["hooks"] = [] - self.sam_state ["gradient"] = {} + self.vaccine_state = {} + self.vaccine_state ["hooks"] = [] + self.vaccine_state ["gradient"] = {} self.pre_first_step(model) step() self.after_first_step(model) @@ -279,7 +276,7 @@ def step(): def pre_first_step(self, model ): def track_gradient_hook(module, grad_input, grad_output): # Store the gradients for the current layer - self.sam_state["gradient"][module] = grad_output[0].detach().clone()/self.args.gradient_accumulation_steps + self.vaccine_state["gradient"][module] = grad_output[0].detach().clone()/self.args.gradient_accumulation_steps # print(grad_output[0]) def apply_backward_hooks_recursive(module, hook_fn, hooks): @@ -289,8 +286,8 @@ def apply_backward_hooks_recursive(module, hook_fn, hooks): # Call the function with the initial empty hooks list leaf_modules_with_grad = get_leaf_modules_with_grad(model) for layer in leaf_modules_with_grad: - self.sam_state["gradient"][layer] = 0 - apply_backward_hooks_recursive(layer, track_gradient_hook, self.sam_state["hooks"]) + self.vaccine_state["gradient"][layer] = 0 + apply_backward_hooks_recursive(layer, track_gradient_hook, self.vaccine_state["hooks"]) @@ -298,7 +295,7 @@ def apply_backward_hooks_recursive(module, hook_fn, hooks): def pre_second_step(self, model): def purturbation_hook(module, input, output): # Modify the output, for example, by adding a perturbatio - perturbation = self.sam_state["gradient"][module] + perturbation = self.vaccine_state["gradient"][module] # print(perturbation[0,1,:]) # # print(output.shape) # print(output[0,1,:]) @@ -317,33 +314,33 @@ def apply_purturbation_hooks_recursive(module, hook_fn, hooks): for layer in leaf_modules_with_grad: # print(layer._get_name()) # Apply hooks to all layers, including nested Sequential blocks - apply_purturbation_hooks_recursive(layer, purturbation_hook, self.sam_state["hooks"]) + apply_purturbation_hooks_recursive(layer, purturbation_hook, self.vaccine_state["hooks"]) @torch.no_grad() def after_first_step(self, model): - for hook in self.sam_state["hooks"]: + for hook in self.vaccine_state["hooks"]: hook.remove() - self.sam_state["hooks"] = [] + self.vaccine_state["hooks"] = [] - # print(self.sam_state["gradient"].items()) - grad_norm = self._grad_norm(self.sam_state["gradient"]) + # print(self.vaccine_state["gradient"].items()) + grad_norm = self._grad_norm(self.vaccine_state["gradient"]) # logging.info(grad_norm) # logging.info("norm{}".format(grad_norm)) - for module in self.sam_state["gradient"]: - # grad_norm = self._grad_norm(self.sam_state["gradient"][module]) - grad = self.sam_state["gradient"][module] + for module in self.vaccine_state["gradient"]: + # grad_norm = self._grad_norm(self.vaccine_state["gradient"][module]) + grad = self.vaccine_state["gradient"][module] scale = self. args. rho / (grad_norm +1e-7) e_r = (grad)* scale - self.sam_state["gradient"][module] = e_r.detach().clone() + self.vaccine_state["gradient"][module] = e_r.detach().clone() @torch.no_grad() def after_second_step(self, model): # disable hook here - # for module in self.sam_state["e_r"]: - # module.weight.data -= self.sam_state["e_r"][module] - for hook in self.sam_state["hooks"]: + # for module in self.vaccine_state["e_r"]: + # module.weight.data -= self.vaccine_state["e_r"][module] + for hook in self.vaccine_state["hooks"]: hook.remove() - self.sam_state["hooks"] = [] + self.vaccine_state["hooks"] = [] # torch.nn.utils.clip_grad_norm_(model.parameters(), 10) @@ -390,9 +387,9 @@ def step(): # print("gere2") return loss - self.sam_state = {} - self.sam_state ["hooks"] = [] - self.sam_state ["gradient"] = {} + self.vaccine_state = {} + self.vaccine_state ["hooks"] = [] + self.vaccine_state ["gradient"] = {} self.pre_second_step(model) loss = step() self.after_second_step(model) @@ -432,17 +429,17 @@ def apply_purturbation_hooks_recursive(module, hook_fn, hooks): for layer in leaf_modules_with_grad: # print(layer._get_name()) # Apply hooks to all layers, including nested Sequential blocks - apply_purturbation_hooks_recursive(layer, purturbation_hook, self.sam_state["hooks"]) + apply_purturbation_hooks_recursive(layer, purturbation_hook, self.vaccine_state["hooks"]) @torch.no_grad() def after_second_step(self, model): # disable hook here - # for module in self.sam_state["e_r"]: - # module.weight.data -= self.sam_state["e_r"][module] - for hook in self.sam_state["hooks"]: + # for module in self.vaccine_state["e_r"]: + # module.weight.data -= self.vaccine_state["e_r"][module] + for hook in self.vaccine_state["hooks"]: hook.remove() - self.sam_state["hooks"] = [] + self.vaccine_state["hooks"] = [] # torch.nn.utils.clip_grad_norm_(model.parameters(), 10) @@ -521,11 +518,5 @@ def step(): loss = step() - # print( sum([torch.norm(self.sam_state ["gradient"][module]) for module in self.sam_state ["gradient"] ])) - # leaf_modules_with_grad = get_leaf_modules_with_grad(model) - # for module in leaf_modules_with_grad: - # # print(module.q_proj.lora_A["default"]) - # module.weight.grad*= (1-self.masks[index]) - # index+=1 self.round+=1 return loss.detach() / self.args.gradient_accumulation_steps