Refactor/projection #25

Open
wants to merge 2 commits into main
42 changes: 28 additions & 14 deletions abliterator.py
@@ -54,17 +54,6 @@ def prepare_dataset(dataset:Tuple[List[str], List[str]]|List[str]) -> Tuple[List
 
     return train, test
 
-def directional_hook(
-    activation: Float[Tensor, "... d_model"],
-    hook: HookPoint,
-    direction: Float[Tensor, "d_model"]
-) -> Float[Tensor, "... d_model"]:
-    if activation.device != direction.device:
-        direction = direction.to(activation.device)
-
-    proj = einops.einsum(activation, direction.view(-1, 1), '... d_model, d_model single -> ... single') * direction
-    return activation - proj
-
 def clear_mem():
     gc.collect()
     torch.cuda.empty_cache()
@@ -247,6 +236,20 @@ def calculate_mean_dirs(self, key: str, include_overall_mean: bool = False) -> D
 
         return dirs
 
+    def calculate_scaled_projection(
+        self,
+        components: Float[Tensor, '... d_model'],
+        direction: Float[Tensor, 'd_model']
+    ) -> Float[Tensor, '... d_model']:
+        return einops.einsum(components, direction.view(-1, 1), '... d_model, d_model column -> ... column') * direction
+
+    def calculate_ortho_complement(
+        self,
+        components: Float[Tensor, '... d_model'],
+        direction: Float[Tensor, 'd_model']
+    ) -> Float[Tensor, '... d_model']:
+        return components - self.calculate_scaled_projection(components, direction)
+
     def get_avg_projections(self, key: str, direction: Float[Tensor, 'd_model']) -> Tuple[Float[Tensor, 'd_model'], Float[Tensor, 'd_model']]:
         dirs = self.calculate_mean_dirs(key)
         return (torch.dot(dirs['harmful_mean'], direction), torch.dot(dirs['harmless_mean'], direction))
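
A note on what the new helpers compute: the einsum contracts the trailing d_model axis of components against direction.view(-1, 1), producing one coefficient per row, and broadcasting then scales direction by that coefficient. For a unit-norm direction this is the standard rank-1 projection (x . d)d, and calculate_ortho_complement subtracts exactly that component. A minimal standalone sketch of the same algebra (plain torch/einops with hypothetical names, outside the class):

import torch
import einops

def scaled_projection(components, direction):
    # (..., d_model) contracted with (d_model, 1) -> (..., 1), then broadcast along d_model
    return einops.einsum(components, direction.view(-1, 1),
                         '... d_model, d_model column -> ... column') * direction

x = torch.randn(4, 16)               # stand-in batch of activations
d = torch.randn(16)
d = d / d.norm()                     # unit norm; the orthogonality check below relies on this

ortho = x - scaled_projection(x, d)  # what calculate_ortho_complement returns
print(torch.allclose(ortho @ d, torch.zeros(4), atol=1e-6))  # True: nothing left along d
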
@@ -257,6 +260,16 @@ def get_layer_dirs(self, layer, key: str = None, include_overall_mean: bool = False
             raise IndexError("Invalid layer")
         return self.calculate_mean_dirs(utils.get_act_name(act_key, layer), include_overall_mean=include_overall_mean)
 
+    def ortho_complement_hook(
+        self,
+        activation: Float[Tensor, '... d_model'],
+        hook: HookPoint,
+        direction: Float[Tensor, 'd_model']
+    ) -> Float[Tensor, '... d_model']:
+        if activation.device != direction.device:
+            direction = direction.to(activation.device)
+        return self.calculate_ortho_complement(activation, direction)
+
     def refusal_dirs(self, invert: bool = False) -> Dict[str, Float[Tensor, 'd_model']]:
         if not self.harmful:
             raise IndexError("No cache")
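
ortho_complement_hook keeps the (activation, hook) calling convention that TransformerLens forward hooks expect, with direction bound in advance. A hedged usage sketch; mdl, refusal_dir, and toks are assumed stand-ins for an instance of this class (wrapping a HookedTransformer as mdl.model), a unit-norm direction, and a token batch:

import functools
from transformer_lens import utils

# Bind `direction` so the remaining signature is (activation, hook),
# which is what fwd_hooks entries expect.
hook_fn = functools.partial(mdl.ortho_complement_hook, direction=refusal_dir)

logits = mdl.model.run_with_hooks(
    toks,
    fwd_hooks=[(utils.get_act_name('resid_pre', layer), hook_fn)
               for layer in range(mdl.model.cfg.n_layers)],
)
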
@@ -415,7 +428,7 @@ def apply_refusal_dirs(
             matrix = modifying[1](layer)
             if refusal_dir.device != matrix.device:
                 refusal_dir = refusal_dir.to(matrix.device)
-            proj = einops.einsum(matrix, refusal_dir.view(-1, 1), '... d_model, d_model single -> ... single') * refusal_dir
+            proj = self.calculate_scaled_projection(matrix, refusal_dir)
             modifying[1](layer, matrix - proj)
 
     def induce_refusal_dir(
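
Routing apply_refusal_dirs through the shared helper leaves the behavior unchanged: subtracting the scaled projection from a weight matrix makes every row orthogonal to the (unit-norm) refusal direction, so the modified weights can no longer write along it. A quick standalone check with hypothetical names:

import torch

W = torch.randn(32, 16)      # stand-in weight matrix, d_model = 16 on the last axis
d = torch.randn(16)
d = d / d.norm()             # unit norm assumed, as in the sketch above

proj = (W @ d).unsqueeze(-1) * d   # same quantity as calculate_scaled_projection(W, d)
W_abl = W - proj
print(torch.allclose(W_abl @ d, torch.zeros(32), atol=1e-6))  # True
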
@@ -434,7 +447,7 @@ def induce_refusal_dir(
             matrix = modifying[1](layer)
             if refusal_dir.device != matrix.device:
                 refusal_dir = refusal_dir.to(matrix.device)
-            proj = einops.einsum(matrix, refusal_dir.view(-1, 1), '... d_model, d_model single -> ... single') * refusal_dir
+            proj = self.calculate_scaled_projection(matrix, refusal_dir)
             avg_proj = refusal_dir * self.get_avg_projections(utils.get_act_name(self.activation_layers[0], layer), refusal_dir)
             modifying[1](layer, (matrix - proj) + avg_proj)
 
@@ -459,7 +472,7 @@ def test_dir(
 
         if use_hooks:
            hooks = self.fwd_hooks
-           hook_fn = functools.partial(directional_hook, direction=refusal_dir)
+           hook_fn = functools.partial(self.ortho_complement_hook, direction=refusal_dir)
            self.fwd_hooks = before_hooks + [(act_name, hook_fn) for ln, act_name in self.get_all_act_names()]
            return self.measure_scores(**kwargs)
        else:
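
One detail worth noting in test_dir: functools.partial works on the bound method self.ortho_complement_hook just as it did on the old module-level directional_hook, because self is already captured by the binding. A toy illustration with hypothetical names:

import functools

class Example:
    def scale(self, x, k):
        return k * x

ex = Example()
f = functools.partial(ex.scale, k=3)  # self is already bound; only k is pre-filled here
print(f(2))  # 6
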
@@ -611,6 +624,7 @@ def create_activation_cache(
         z_label = [] if measure_refusal > 1 else None
         for i in tqdm(range(0, min(N, len(toks)), batch_size)):
             logits, cache = self.run_with_cache(toks[i:min(i+batch_size, len(toks))], max_new_tokens=measure_refusal, stop_at_layer=stop_at_layer)
+
             if measure_refusal > 1:
                 z_label.extend(self.measure_scores_from_logits(logits, measure_refusal)[0])
             for key in cache: