diff --git a/abliterator.py b/abliterator.py index 00e3bc3..0b30174 100644 --- a/abliterator.py +++ b/abliterator.py @@ -54,17 +54,6 @@ def prepare_dataset(dataset:Tuple[List[str], List[str]]|List[str]) -> Tuple[List return train, test -def directional_hook( - activation: Float[Tensor, "... d_model"], - hook: HookPoint, - direction: Float[Tensor, "d_model"] -) -> Float[Tensor, "... d_model"]: - if activation.device != direction.device: - direction = direction.to(activation.device) - - proj = einops.einsum(activation, direction.view(-1, 1), '... d_model, d_model single -> ... single') * direction - return activation - proj - def clear_mem(): gc.collect() torch.cuda.empty_cache() @@ -247,6 +236,20 @@ def calculate_mean_dirs(self, key: str, include_overall_mean: bool = False) -> D return dirs + def calculate_scaled_projection( + self, + components: Float[Tensor, '... d_model'], + direction: Float[Tensor, 'd_model'] + ) -> Float[Tensor, '... d_model']: + return einops.einsum(components, direction.view(-1, 1), '... d_model, d_model column -> ... column') * direction + + def calculate_ortho_complement( + self, + components: Float[Tensor, '... d_model'], + direction: Float[Tensor, 'd_model'] + ) -> Float[Tensor, '... d_model']: + return components - self.calculate_scaled_projection(components, direction) + def get_avg_projections(self, key: str, direction: Float[Tensor, 'd_model']) -> Tuple[Float[Tensor, 'd_model'], Float[Tensor, 'd_model']]: dirs = self.calculate_mean_dirs(self,key) return (torch.dot(dirs['harmful_mean'], direction), torch.dot(dirs['harmless_mean'], direction)) @@ -257,6 +260,16 @@ def get_layer_dirs(self, layer, key: str = None, include_overall_mean: bool=Fals raise IndexError("Invalid layer") return self.calculate_mean_dirs(utils.get_act_name(act_key, layer), include_overall_mean=include_overall_mean) + def ortho_complement_hook( + self, + activation: Float[Tensor, '... d_model'], + hook: HookPoint, + direction: Float[Tensor, 'd_model'] + ) -> Float[Tensor, '... d_model']: + if activation.device != direction.device: + direction = direction.to(activation.device) + return self.calculate_ortho_complement(activation, direction) + def refusal_dirs(self, invert: bool = False) -> Dict[str, Float[Tensor, 'd_model']]: if not self.harmful: raise IndexError("No cache") @@ -415,7 +428,7 @@ def apply_refusal_dirs( matrix = modifying[1](layer) if refusal_dir.device != matrix.device: refusal_dir = refusal_dir.to(matrix.device) - proj = einops.einsum(matrix, refusal_dir.view(-1, 1), '... d_model, d_model single -> ... single') * refusal_dir + proj = self.calculate_scaled_projection(matrix, refusal_dir) modifying[1](layer,matrix - proj) def induce_refusal_dir( @@ -434,7 +447,7 @@ def induce_refusal_dir( matrix = modifying[1](layer) if refusal_dir.device != matrix.device: refusal_dir = refusal_dir.to(matrix.device) - proj = einops.einsum(matrix, refusal_dir.view(-1, 1), '... d_model, d_model single -> ... single') * refusal_dir + proj = self.calculate_scaled_projection(matrix, refusal_dir) avg_proj = refusal_dir * self.get_avg_projections(utils.get_act_name(self.activation_layers[0], layer),refusal_dir) modifying[1](layer,(matrix - proj) + avg_proj) @@ -459,7 +472,7 @@ def test_dir( if use_hooks: hooks = self.fwd_hooks - hook_fn = functools.partial(directional_hook,direction=refusal_dir) + hook_fn = functools.partial(self.ortho_complement_hook,direction=refusal_dir) self.fwd_hooks = before_hooks+[(act_name,hook_fn) for ln,act_name in self.get_all_act_names()] return self.measure_scores(**kwargs) else: @@ -611,6 +624,7 @@ def create_activation_cache( z_label = [] if measure_refusal > 1 else None for i in tqdm(range(0,min(N,len(toks)),batch_size)): logits,cache = self.run_with_cache(toks[i:min(i+batch_size,len(toks))],max_new_tokens=measure_refusal,stop_at_layer=stop_at_layer) + if measure_refusal > 1: z_label.extend(self.measure_scores_from_logits(logits,measure_refusal)[0]) for key in cache: