Skip to content

Commit

Permalink
More tests for the negative case
Browse files Browse the repository at this point in the history
  • Loading branch information
curt-tigges committed Jul 17, 2024
1 parent 845d5d7 commit a0b0f54
Showing 1 changed file with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions tests/unit/analysis/test_hooked_sae_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,3 +581,68 @@ def test_run_with_cache_with_saes_use_error_term_true(
cache_with_sae[hooked_sae.cfg.hook_name + ".hook_sae_output"],
atol=1e-5,
)


def test_add_sae_with_use_error_term_false(
    model: HookedSAETransformer,
    hooked_sae: SAE,
):
    """Verifies that add_sae with use_error_term=False changes the model output.

    The logits from a plain forward pass are compared with the logits obtained
    after attaching ``hooked_sae`` via ``model.add_sae``; the test passes only
    when the two differ beyond an absolute tolerance of 1e-5.
    """
    # Baseline: logits before any SAE is attached.
    output_without_sae = get_logits(model(prompt))

    # Attach the SAE with the error term disabled.
    model.add_sae(hooked_sae, use_error_term=False)
    try:
        output_with_sae = get_logits(model(prompt))

        # Negative case: the outputs are expected to differ.
        assert not torch.allclose(output_without_sae, output_with_sae, atol=1e-5)
    finally:
        # Always detach the SAE, even when the forward pass or the assertion
        # fails, so a failure here cannot leave the SAE attached to the model
        # in case the fixture is shared with other tests.
        model.reset_saes()


def test_run_with_saes_use_error_term_false(
    model: HookedSAETransformer,
    hooked_sae: SAE,
):
    """Check that run_with_saes with use_error_term=False alters the logits.

    A plain forward pass provides the reference logits; the same prompt is
    then run through ``model.run_with_saes`` with ``hooked_sae`` and the error
    term disabled. The two sets of logits must not be close (atol=1e-5).
    """
    # Reference logits from the plain model call.
    baseline_logits = get_logits(model(prompt))

    # Same prompt, run through run_with_saes with use_error_term=False.
    spliced_output = model.run_with_saes(
        prompt, saes=[hooked_sae], use_error_term=False
    )
    spliced_logits = get_logits(spliced_output)

    # Negative case: the two results are expected to diverge.
    logits_match = torch.allclose(baseline_logits, spliced_logits, atol=1e-5)
    assert not logits_match


def test_run_with_cache_with_saes_use_error_term_false(
    model: HookedSAETransformer,
    hooked_sae: SAE,
):
    """Check run_with_cache_with_saes with use_error_term=False changes results.

    Three things are asserted:
      1. the logits differ from those of a plain ``run_with_cache`` call;
      2. the SAE run's cache has an entry for the SAE's ``hook_sae_acts_post``;
      3. the plain cache entry at the SAE's hook name differs from the SAE
         run's ``hook_sae_output`` entry (atol=1e-5).
    """
    # Reference run: plain model with activation cache.
    plain_result, plain_cache = model.run_with_cache(prompt)
    plain_logits = get_logits(plain_result)

    # Run under test: same prompt with the SAE and the error term disabled.
    sae_result, sae_cache = model.run_with_cache_with_saes(
        prompt, saes=[hooked_sae], use_error_term=False
    )
    sae_logits = get_logits(sae_result)

    # 1. Final logits are expected to diverge.
    assert not torch.allclose(plain_logits, sae_logits, atol=1e-5)

    hook_name = hooked_sae.cfg.hook_name

    # 2. The SAE run's cache must contain the SAE activations entry.
    acts_key = hook_name + ".hook_sae_acts_post"
    assert acts_key in sae_cache

    # 3. The plain cache entry and the SAE output entry must differ.
    plain_act = plain_cache[hook_name]
    sae_out_act = sae_cache[hook_name + ".hook_sae_output"]
    assert not torch.allclose(plain_act, sae_out_act, atol=1e-5)

0 comments on commit a0b0f54

Please sign in to comment.