Commit

add hook q
jbloom-md committed Dec 19, 2023
1 parent 1dc893a commit b061ee3
Showing 1 changed file with 60 additions and 19 deletions.
79 changes: 60 additions & 19 deletions sae_analysis/visualizer/data_fns.py
@@ -522,6 +522,7 @@ def get_feature_data(
model: HookedTransformer,
hook_point: str,
hook_point_layer: int,
hook_point_head_index: Optional[int],
tokens: Int[Tensor, "batch seq"],
feature_idx: Union[int, List[int]],
max_batch_size: Optional[int] = None,
@@ -573,10 +574,17 @@ def get_feature_data(
# corrcoef_encoder_B = BatchedCorrCoef()

# Get encoder & decoder directions
feature_act_dir = encoder.W_enc[:, feature_idx]  # (d_in, feats)
feature_bias = encoder.b_enc[feature_idx]  # (feats,)
feature_out_dir = encoder.W_dec[feature_idx]  # (feats, d_in)

if "resid_pre" in hook_point:
feature_mlp_out_dir = feature_out_dir # (feats, d_model)
elif "resid_post" in hook_point:
feature_mlp_out_dir = feature_out_dir @ model.W_out[hook_point_layer] # (feats, d_model)
elif "hook_q" in hook_point:
# unembed proj onto residual stream
feature_mlp_out_dir = feature_out_dir @ model.W_Q[hook_point_layer, hook_point_head_index].T # (feats, d_model)ß
assert feature_act_dir.T.shape == feature_out_dir.shape == (len(feature_idx), encoder.cfg.d_in)
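# Shape sketch for the hook_q branch (illustrative only; assumes the SAE was trained on a single
# head's queries, so encoder.cfg.d_in == d_head):
#   feature_out_dir:                                       (feats, d_head)
#   model.W_Q[hook_point_layer, hook_point_head_index].T:  (d_head, d_model)
#   feature_mlp_out_dir (their product):                   (feats, d_model)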

t1 = time.time()
@@ -617,35 +625,68 @@ def hook_fn_act_post(act_post: Float[Tensor, "batch seq d_mlp"], hook: HookPoint):
# einops.rearrange(feat_acts_B, "batch seq d_hidden -> d_hidden (batch seq)"),
# )

def hook_fn_query(hook_q: Float[Tensor, "batch seq n_head d_head"], hook: HookPoint):
'''
Replaces act_post with the projection of the query onto the residual stream via W_Q^T.

The encoder has learned x^j \approx b + \sum_i f_i(x^j)d_i where:
- f_i are the feature activations
- d_i are the feature output directions

This hook function stores all the information we'll need later on. It doesn't actually perform feature ablation,
because if we did, we'd have to run a different fwd pass for every feature, which is super wasteful! Instead, we'll
later calculate the effect of feature ablation, i.e. x^j <- x^j - f_i(x^j)d_i for i = feature_idx, only on the
tokens we care about (the ones which will appear in the visualisation).
'''
# Calculate & store the feature activations (we need to store them so we can get the right-hand visualisations later)
hook_q = hook_q[:, :, hook_point_head_index]
x_cent = hook_q - encoder.b_dec
feat_acts_pre = einops.einsum(x_cent, feature_act_dir, "batch seq d_head, d_head feats -> batch seq feats")
feat_acts = F.relu(feat_acts_pre + feature_bias)
all_feat_acts.append(feat_acts)
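# Note: the block above implements the standard SAE encoder, f(x) = ReLU(W_enc^T (x - b_dec) + b_enc),
# applied to the selected head's query and restricted to the feature_idx columns of W_enc.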

# project this back up to resid stream size.
act_resid_proj = hook_q @ model.W_Q[hook_point_layer, hook_point_head_index].T

# Update the CorrCoef object between feature activations & residual-stream directions
corrcoef_neurons.update(
einops.rearrange(feat_acts, "batch seq feats -> feats (batch seq)"),
einops.rearrange(act_resid_proj, "batch seq d_model -> d_model (batch seq)"),
)
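# Illustrative sketch of the deferred ablation described in the docstring (not performed here;
# `i` indexes a single feature):
#   hook_q_ablated = hook_q - feat_acts[..., i, None] * feature_out_dir[i]   # x^j <- x^j - f_i(x^j)d_i
# Doing this inside the fwd pass would need one pass per feature, hence it's computed later, and only
# for the tokens which appear in the visualisation.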


def hook_fn_resid_post(resid_post: Float[Tensor, "batch seq d_model"], hook: HookPoint):
'''
This hook function stores the residual activations, which we'll need later on to calculate the effect of feature ablation.
'''
all_resid_post.append(resid_post)
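# (These stored residuals are later combined with the feature activations to measure each
# feature's ablation effect on the model's output.)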


# ! Run the forward passes (triggering the hooks), concat all results

# Run the model with hooks (to store all the information we need; we don't actually return anything)

# If we were using an MLP activation, then we'd want this one.
# However, since we may have different residual stream positions, the analysis will change each time.
# Let's deal with this later.
# for _tokens in all_tokens:
# model.run_with_hooks(_tokens, return_type=None, fwd_hooks=[
# (utils.get_act_name("post", 0), hook_fn_act_post),
# (utils.get_act_name("resid_post", 0), hook_fn_resid_post)
# ])

iterator = tqdm(all_tokens, desc="Storing model activations")

if "resid_pre" in hook_point:
    for _tokens in iterator:
        model.run_with_hooks(_tokens, return_type=None, fwd_hooks=[
            (hook_point, hook_fn_act_post),
            (utils.get_act_name("resid_pre", hook_point_layer), hook_fn_resid_post)
        ])
elif "resid_post" in hook_point:
    # If we are using MLP activations, then we want this one.
    for _tokens in iterator:
        model.run_with_hooks(_tokens, return_type=None, fwd_hooks=[
            (utils.get_act_name("post", hook_point_layer), hook_fn_act_post),
            (utils.get_act_name("resid_post", hook_point_layer), hook_fn_resid_post)
        ])
elif "hook_q" in hook_point:
    for _tokens in iterator:
        model.run_with_hooks(_tokens, return_type=None, fwd_hooks=[
            (hook_point, hook_fn_query),
            (utils.get_act_name("resid_post", hook_point_layer), hook_fn_resid_post)
        ])
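# Each branch registers two hooks per forward pass: one computing & storing the SAE feature
# activations at hook_point, and one stashing the residual stream for the later ablation analysis.
# For the new hook_q path, a hypothetical call might look like (illustrative only; other required
# arguments elided):
#   get_feature_data(..., hook_point=utils.get_act_name("q", 5),  # "blocks.5.attn.hook_q"
#                    hook_point_layer=5, hook_point_head_index=2, ...)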


t2 = time.time()
