Skip to content

Commit

Permalink
Don't precompute background colors and tick values
Browse files Browse the repository at this point in the history
  • Loading branch information
hijohnnylin committed Apr 10, 2024
1 parent d532b82 commit 271dbf0
Showing 1 changed file with 0 additions and 32 deletions.
32 changes: 0 additions & 32 deletions sae_lens/analysis/neuronpedia_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,20 +280,6 @@ def run(self):
feature.logits_table_data.bottom_logits
)

# TODO: don't precompute/store these. should do it on the frontend
max_value = max(
np.absolute(bottom10_logits).max(),
np.absolute(top10_logits).max(),
)
neg_bg_values = self.round_list(
np.absolute(bottom10_logits) / max_value
)
pos_bg_values = self.round_list(
np.absolute(top10_logits) / max_value
)
feature_output["neg_bg_values"] = neg_bg_values
feature_output["pos_bg_values"] = pos_bg_values

if feature.feature_tables_data:
feature_output["neuron_alignment_indices"] = (
feature.feature_tables_data.neuron_alignment_indices
Expand Down Expand Up @@ -332,7 +318,6 @@ def run(self):
)
feature_output["pos_values"] = top10_logits

# TODO: don't know what this should be in the new version
feature_output["frac_nonzero"] = (
feature.acts_histogram_data.title.split(" = ")[1]
if feature.acts_histogram_data.title is not None
Expand All @@ -342,22 +327,9 @@ def run(self):
freq_hist_data = feature.acts_histogram_data
freq_bar_values = self.round_list(freq_hist_data.bar_values)
feature_output["freq_hist_data_bar_values"] = freq_bar_values
feature_output["freq_hist_data_tick_vals"] = self.round_list(
freq_hist_data.tick_vals
)

# TODO: don't precompute/store these. should do it on the frontend
freq_bar_values_clipped = [
(0.4 * max(freq_bar_values) + 0.6 * v) / max(freq_bar_values)
for v in freq_bar_values
]
freq_bar_colors = [
colors.rgb2hex(BG_COLOR_MAP(v)) for v in freq_bar_values_clipped
]
feature_output["freq_hist_data_bar_heights"] = self.round_list(
freq_hist_data.bar_heights
)
feature_output["freq_bar_colors"] = freq_bar_colors

logits_hist_data = feature.logits_histogram_data
feature_output["logits_hist_data_bar_heights"] = self.round_list(
Expand All @@ -366,11 +338,7 @@ def run(self):
feature_output["logits_hist_data_bar_values"] = self.round_list(
logits_hist_data.bar_values
)
feature_output["logits_hist_data_tick_vals"] = self.round_list(
logits_hist_data.tick_vals
)

# TODO: check this
feature_output["num_tokens_for_dashboard"] = (
self.n_prompts_to_select
)
Expand Down

0 comments on commit 271dbf0

Please sign in to comment.