feat: Replace assert statements with exception code (#400)
* replaces assert statements with exception code

* replaces assert statements with exception code in less obvious cases

* removes unnecessary if and else statements
anthonyduong9 authored Dec 29, 2024
1 parent 3a1f388 commit 324be25
Showing 14 changed files with 121 additions and 86 deletions.
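
Every file follows the same pattern: a bare assert condition, message becomes an explicit check that raises a specific exception, so validation still runs when Python is invoked with -O and callers can catch a meaningful error type. A minimal sketch of the pattern (validate_context_size is an illustrative name, not code from this repository):

def validate_context_size(context_size: int) -> int:
    # Before: assert context_size > 0, "context_size must be positive"
    # An assert is stripped under `python -O` and can only raise AssertionError,
    # which callers cannot distinguish from other failed assertions.
    if context_size <= 0:
        raise ValueError(f"context_size must be positive, got {context_size}")
    return context_size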
5 changes: 2 additions & 3 deletions sae_lens/analysis/hooked_sae_transformer.py
@@ -141,9 +141,8 @@ def reset_saes(
act_names = list(self.acts_to_saes.keys())

if prev_saes:
assert len(act_names) == len(
prev_saes
), "act_names and prev_saes must have the same length"
if len(act_names) != len(prev_saes):
raise ValueError("act_names and prev_saes must have the same length")
else:
prev_saes = [None] * len(act_names) # type: ignore

5 changes: 4 additions & 1 deletion sae_lens/analysis/neuronpedia_integration.py
@@ -365,7 +365,10 @@ async def autointerp_neuronpedia_features( # noqa: C901
f"ERROR: Failed to explain feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature}"
)

assert len(explanations) == 1
if len(explanations) != 1:
raise ValueError(
f"Expected exactly one explanation but got {len(explanations)}. This may indicate an issue with the explainer's response."
)
explanation = explanations[0].rstrip(".")
logger.info(
f"===== {autointerp_explainer_model_name}'s explanation: {explanation}"
22 changes: 16 additions & 6 deletions sae_lens/cache_activations_runner.py
@@ -159,12 +159,22 @@ def _consolidate_shards(
"""
first_shard_dir_name = "shard_00000" # shard_{i:05d}

assert source_dir.exists() and source_dir.is_dir()
assert (
output_dir.exists()
and output_dir.is_dir()
and not any(p for p in output_dir.iterdir() if p.name != ".tmp_shards")
)
if not source_dir.exists() or not source_dir.is_dir():
raise NotADirectoryError(
f"source_dir is not an existing directory: {source_dir}"
)

if not output_dir.exists() or not output_dir.is_dir():
raise NotADirectoryError(
f"output_dir is not an existing directory: {output_dir}"
)

other_items = [p for p in output_dir.iterdir() if p.name != ".tmp_shards"]
if other_items:
raise FileExistsError(
f"output_dir must be empty (besides .tmp_shards). Found: {other_items}"
)

if not (source_dir / first_shard_dir_name).exists():
raise Exception(f"No shards in {source_dir} exist!")

16 changes: 11 additions & 5 deletions sae_lens/config.py
@@ -566,7 +566,9 @@ def __post_init__(self):
)
token_length = len(toks)
self.context_size = token_length
assert self.context_size != -1

if self.context_size == -1:
raise ValueError("context_size is still -1 after dataset inspection.")

if self.seqpos_slice is not None:
_validate_seqpos(
@@ -704,11 +706,15 @@ def _validate_seqpos(seqpos: tuple[int | None, ...], context_size: int) -> None:
# Ensure that the step-size is larger or equal to 1
if len(seqpos) == 3:
step_size = seqpos[2] or 1
assert (
step_size > 1
), f"Ensure the step_size {seqpos[2]=} for sequence slicing is positive."
if step_size <= 1:
raise ValueError(
f"Ensure the step_size={seqpos[2]} for sequence slicing is at least 1."
)
# Ensure that the choice of seqpos doesn't end up with an empty list
assert len(list(range(context_size))[slice(*seqpos)]) > 0
if len(list(range(context_size))[slice(*seqpos)]) == 0:
raise ValueError(
f"The slice {seqpos} results in an empty range. Please adjust your seqpos or context_size."
)


@dataclass
15 changes: 12 additions & 3 deletions sae_lens/evals.py
@@ -130,7 +130,11 @@ def run_evals(
}

if eval_config.compute_kl or eval_config.compute_ce_loss:
assert eval_config.n_eval_reconstruction_batches > 0
if eval_config.n_eval_reconstruction_batches <= 0:
raise ValueError(
"eval_config.n_eval_reconstruction_batches must be > 0 when "
"compute_kl or compute_ce_loss is True."
)
reconstruction_metrics = get_downstream_reconstruction_metrics(
sae,
model,
@@ -175,7 +179,11 @@ def run_evals(
or eval_config.compute_sparsity_metrics
or eval_config.compute_variance_metrics
):
assert eval_config.n_eval_sparsity_variance_batches > 0
if eval_config.n_eval_sparsity_variance_batches <= 0:
raise ValueError(
"eval_config.n_eval_sparsity_variance_batches must be > 0 when "
"compute_l2_norms, compute_sparsity_metrics, or compute_variance_metrics is True."
)
sparsity_variance_metrics, feature_metrics = get_sparsity_and_variance_metrics(
sae,
model,
@@ -743,7 +751,8 @@ def multiple_evals(

filtered_saes = get_saes_from_regex(sae_regex_pattern, sae_block_pattern)

assert len(filtered_saes) > 0, "No SAEs matched the given regex patterns"
if len(filtered_saes) == 0:
raise ValueError("No SAEs matched the given regex patterns")

eval_results = []
output_path = Path(output_dir)
31 changes: 19 additions & 12 deletions sae_lens/load_model.py
@@ -130,18 +130,25 @@ def to_tokens(
# Assumes that prepend_bos is always False, move_to_device is always False, and truncate is always False
# copied from HookedTransformer.to_tokens

assert (
prepend_bos is False
), "Only works with prepend_bos=False, to match ActivationsStore usage"
assert (
padding_side is None
), "Only works with padding_side=None, to match ActivationsStore usage"
assert (
truncate is False
), "Only works with truncate=False, to match ActivationsStore usage"
assert (
move_to_device is False
), "Only works with move_to_device=False, to match ActivationsStore usage"
if prepend_bos is not False:
raise ValueError(
"Only works with prepend_bos=False, to match ActivationsStore usage"
)

if padding_side is not None:
raise ValueError(
"Only works with padding_side=None, to match ActivationsStore usage"
)

if truncate is not False:
raise ValueError(
"Only works with truncate=False, to match ActivationsStore usage"
)

if move_to_device is not False:
raise ValueError(
"Only works with move_to_device=False, to match ActivationsStore usage"
)

tokens = self.tokenizer(
input,
8 changes: 4 additions & 4 deletions sae_lens/sae.py
@@ -648,9 +648,8 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "SAE":
return cls(SAEConfig.from_dict(config_dict))

def turn_on_forward_pass_hook_z_reshaping(self):
assert self.cfg.hook_name.endswith(
"_z"
), "This method should only be called for hook_z SAEs."
if not self.cfg.hook_name.endswith("_z"):
raise ValueError("This method should only be called for hook_z SAEs.")

def reshape_fn_in(x: torch.Tensor):
self.d_head = x.shape[-1] # type: ignore
@@ -703,7 +702,8 @@ def tanh_relu(input: torch.Tensor) -> torch.Tensor:

return tanh_relu
if activation_fn == "topk":
assert "k" in kwargs, "TopK activation function requires a k value."
if "k" not in kwargs:
raise ValueError("TopK activation function requires a k value.")
k = kwargs.get("k", 1) # Default k to 1 if not provided
postact_fn = kwargs.get(
"postact_fn", nn.ReLU()
5 changes: 2 additions & 3 deletions sae_lens/tokenization_and_batching.py
@@ -79,9 +79,8 @@ def concat_and_batch_sequences(
"""
batch: torch.Tensor | None = None
for tokens in tokens_iterator:
assert (
len(tokens.shape) == 1
), f"tokens.shape should be 1D but was {tokens.shape}"
if len(tokens.shape) != 1:
raise ValueError(f"tokens.shape should be 1D but was {tokens.shape}")
offset = 0
total_toks = tokens.shape[0]
is_start_of_sequence = True
14 changes: 8 additions & 6 deletions sae_lens/toolkit/pretrained_sae_loaders.py
@@ -233,9 +233,10 @@ def read_sae_from_disk(
del state_dict["scaling_factor"]
cfg_dict["finetuning_scaling_factor"] = False
else:
assert cfg_dict[
"finetuning_scaling_factor"
], "Scaling factor is present but finetuning_scaling_factor is False."
if not cfg_dict["finetuning_scaling_factor"]:
raise ValueError(
"Scaling factor is present but finetuning_scaling_factor is False."
)
state_dict["finetuning_scaling_factor"] = state_dict["scaling_factor"]
del state_dict["scaling_factor"]
else:
@@ -391,9 +392,10 @@ def gemma_2_sae_loader(
del state_dict["scaling_factor"]
cfg_dict["finetuning_scaling_factor"] = False
else:
assert cfg_dict[
"finetuning_scaling_factor"
], "Scaling factor is present but finetuning_scaling_factor is False."
if not cfg_dict["finetuning_scaling_factor"]:
raise ValueError(
"Scaling factor is present but finetuning_scaling_factor is False."
)
state_dict["finetuning_scaling_factor"] = state_dict.pop("scaling_factor")
else:
cfg_dict["finetuning_scaling_factor"] = False
9 changes: 5 additions & 4 deletions sae_lens/toolkit/pretrained_saes.py
@@ -66,10 +66,11 @@ def get_gpt2_res_jb_saes(
] + ["blocks.11.hook_resid_post"]

if hook_point is not None:
assert hook_point in GPT2_SMALL_RESIDUAL_SAES_HOOK_POINTS, (
f"hook_point must be one of {GPT2_SMALL_RESIDUAL_SAES_HOOK_POINTS}"
f"but got {hook_point}"
)
if hook_point not in GPT2_SMALL_RESIDUAL_SAES_HOOK_POINTS:
raise ValueError(
f"hook_point must be one of {GPT2_SMALL_RESIDUAL_SAES_HOOK_POINTS} "
f"but got {hook_point}"
)
GPT2_SMALL_RESIDUAL_SAES_HOOK_POINTS = [hook_point]

saes = {}
3 changes: 2 additions & 1 deletion sae_lens/toolkit/pretrained_saes_directory.py
@@ -33,7 +33,8 @@ def get_pretrained_saes_directory() -> dict[str, PretrainedSAELookup]:
l0_map: dict[str, float] = {}
neuronpedia_id_map: dict[str, str] = {}

assert "saes" in value, f"Missing 'saes' key in {release}"
if "saes" not in value:
raise KeyError(f"Missing 'saes' key in {release}")
for hook_info in value["saes"]:
saes_map[hook_info["id"]] = hook_info["path"]
var_explained_map[hook_info["id"]] = hook_info.get(
31 changes: 22 additions & 9 deletions sae_lens/training/activations_store.py
@@ -318,9 +318,8 @@ def _iterate_raw_dataset_tokens(self) -> Generator[torch.Tensor, None, None]:
.squeeze(0)
.to(self.device)
)
assert (
len(tokens.shape) == 1
), f"tokens.shape should be 1D but was {tokens.shape}"
if len(tokens.shape) != 1:
raise ValueError(f"tokens.shape should be 1D but was {tokens.shape}")
yield tokens

def _iterate_tokenized_sequences(self) -> Generator[torch.Tensor, None, None]:
@@ -365,9 +364,11 @@ def load_cached_activation_dataset(self) -> Dataset | None:

assert self.cached_activations_path is not None # keep pyright happy
# Sanity check: does the cache directory exist?
assert os.path.exists(
self.cached_activations_path
), f"Cache directory {self.cached_activations_path} does not exist. Consider double-checking your dataset, model, and hook names."
if not os.path.exists(self.cached_activations_path):
raise FileNotFoundError(
f"Cache directory {self.cached_activations_path} does not exist. "
"Consider double-checking your dataset, model, and hook names."
)

# ---
# Actual code
@@ -563,7 +564,11 @@ def _load_buffer_from_cached(
assert self.cached_activation_dataset is not None
# In future, could be a list of multiple hook names
hook_names = [self.hook_name]
assert set(hook_names).issubset(self.cached_activation_dataset.column_names)
if not set(hook_names).issubset(self.cached_activation_dataset.column_names):
raise ValueError(
f"Missing columns in dataset. Expected {hook_names}, "
f"got {self.cached_activation_dataset.column_names}."
)

if self.current_row_idx > len(self.cached_activation_dataset) - total_size:
self.current_row_idx = 0
@@ -577,13 +582,21 @@
_hook_buffer = self.cached_activation_dataset[
self.current_row_idx : self.current_row_idx + total_size
][hook_name]
assert _hook_buffer.shape == (total_size, context_size, d_in)
if _hook_buffer.shape != (total_size, context_size, d_in):
raise ValueError(
f"_hook_buffer has shape {_hook_buffer.shape}, "
f"but expected ({total_size}, {context_size}, {d_in})."
)
new_buffer.append(_hook_buffer)

# Stack across num_layers dimension
# list of num_layers; shape: (total_size, context_size, d_in) -> (total_size, context_size, num_layers, d_in)
new_buffer = torch.stack(new_buffer, dim=2)
assert new_buffer.shape == (total_size, context_size, num_layers, d_in)
if new_buffer.shape != (total_size, context_size, num_layers, d_in):
raise ValueError(
f"new_buffer has shape {new_buffer.shape}, "
f"but expected ({total_size}, {context_size}, {num_layers}, {d_in})."
)

self.current_row_idx += total_size
return new_buffer.reshape(total_size * context_size, num_layers, d_in)
5 changes: 4 additions & 1 deletion sae_lens/training/optim.py
@@ -119,7 +119,10 @@ def __init__(

self.current_step = 0
self.total_steps = total_steps
assert isinstance(self.final_l1_coefficient, float | int)
if not isinstance(self.final_l1_coefficient, (float, int)):
raise TypeError(
f"final_l1_coefficient must be float or int, got {type(self.final_l1_coefficient)}."
)

def __repr__(self) -> str:
return (
38 changes: 10 additions & 28 deletions tests/unit/training/test_config.py
@@ -1,5 +1,5 @@
from dataclasses import fields
from typing import Optional
from typing import Type

import pytest

@@ -131,25 +131,19 @@ def test_sae_training_runner_config_expansion_factor():


test_cases_for_seqpos = [
((None, 10, -1), AssertionError),
((None, 10, 0), AssertionError),
((5, 5, None), AssertionError),
((6, 3, None), AssertionError),
((None, 10, -1), ValueError),
((None, 10, 0), ValueError),
((5, 5, None), ValueError),
((6, 3, None), ValueError),
]


@pytest.mark.parametrize("seqpos_slice, expected_error", test_cases_for_seqpos)
def test_sae_training_runner_config_seqpos(
seqpos_slice: tuple[int, int], expected_error: Optional[AssertionError]
seqpos_slice: tuple[int, int], expected_error: Type[BaseException]
):
context_size = 10
if expected_error is AssertionError:
with pytest.raises(expected_error):
LanguageModelSAERunnerConfig(
seqpos_slice=seqpos_slice,
context_size=context_size,
)
else:
with pytest.raises(expected_error):
LanguageModelSAERunnerConfig(
seqpos_slice=seqpos_slice,
context_size=context_size,
@@ -158,22 +152,10 @@ def test_sae_training_runner_config_seqpos(

@pytest.mark.parametrize("seqpos_slice, expected_error", test_cases_for_seqpos)
def test_cache_activations_runner_config_seqpos(
seqpos_slice: tuple[int, int], expected_error: Optional[AssertionError]
seqpos_slice: tuple[int, int],
expected_error: Type[BaseException],
):
if expected_error is AssertionError:
with pytest.raises(expected_error):
CacheActivationsRunnerConfig(
dataset_path="",
model_name="",
model_batch_size=1,
hook_name="",
hook_layer=0,
d_in=1,
training_tokens=100,
context_size=10,
seqpos_slice=seqpos_slice,
)
else:
with pytest.raises(expected_error):
CacheActivationsRunnerConfig(
dataset_path="",
model_name="",
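
With the configs now raising ValueError, the parametrized tests above assert the exception type directly instead of branching on AssertionError. A condensed sketch of the updated test shape, assuming LanguageModelSAERunnerConfig is importable from sae_lens.config and reusing two of the parametrized cases shown in the diff:

import pytest

from sae_lens.config import LanguageModelSAERunnerConfig


@pytest.mark.parametrize(
    "seqpos_slice, expected_error",
    [((None, 10, -1), ValueError), ((None, 10, 0), ValueError)],
)
def test_invalid_seqpos_raises(seqpos_slice, expected_error):
    # The config validates seqpos_slice against context_size in __post_init__,
    # so constructing it with an invalid slice should raise ValueError.
    with pytest.raises(expected_error):
        LanguageModelSAERunnerConfig(
            seqpos_slice=seqpos_slice,
            context_size=10,
        )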
