Skip to content

Commit

Permalink
Merge pull request #18 from chanind/type-checking
Browse files Browse the repository at this point in the history
chore: setting up pyright type checking and fixing typing errors
  • Loading branch information
jbloomAus authored Mar 1, 2024
2 parents 3e78bce + 57c4582 commit bd5fc43
Show file tree
Hide file tree
Showing 24 changed files with 326 additions and 197 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ jobs:
run: poetry run black . --check
- name: isort linting
run: poetry run isort . --check-only --diff
- name: type checking
run: poetry run pyright
- name: Run Unit Tests
run: |
make unit-test
run: make unit-test
31 changes: 13 additions & 18 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
{
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,

"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": false,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
},
},
"isort.args": [
"--profile",
"black"
],
"editor.defaultFormatter": "mikoz.black-py",
"python.testing.pytestArgs": ["tests"],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,

"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
}
},
"isort.args": ["--profile", "black"],
"editor.defaultFormatter": "mikoz.black-py"
}
15 changes: 15 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,25 @@ pytest-cov = "^4.1.0"
pre-commit = "^3.6.2"
flake8 = "^7.0.0"
isort = "^5.13.2"
pyright = "^1.1.351"

[tool.isort]
profile = "black"

[tool.pyright]
exclude = ["./sae_training/geom_median/"]

typeCheckingMode = "strict"
reportMissingTypeStubs = "none"
reportUnknownMemberType = "none"
reportUnknownArgumentType = "none"
reportUnknownVariableType = "none"
reportUntypedFunctionDecorator = "none"
reportUnnecessaryIsInstance = "none"
reportUnnecessaryComparison = "none"
reportConstantRedefinition = "none"
reportUnknownLambdaType = "none"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
28 changes: 20 additions & 8 deletions sae_analysis/dashboard_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# TODO: are these sys.path.append calls really necessary?

import sys
from typing import Any, cast

sys.path.append("..")
sys.path.append("../..")
Expand All @@ -28,9 +29,9 @@
class DashboardRunner:
def __init__(
self,
sae_path: str = None,
sae_path: str | None = None,
dashboard_parent_folder: str = "./feature_dashboards",
wandb_artifact_path: str = None,
wandb_artifact_path: str | None = None,
init_session: bool = True,
# token pars
n_batches_to_sample_from: int = 2**12,
Expand All @@ -42,7 +43,7 @@ def __init__(
# util pars
use_wandb: bool = False,
continue_existing_dashboard: bool = True,
final_index: int = None,
final_index: int | None = None,
):
"""
# # test it
Expand Down Expand Up @@ -71,6 +72,7 @@ def __init__(
if not os.path.exists(artifact_dir):
print("Downloading artifact")
run = wandb.init()
assert run is not None # keep pyright happy
artifact = run.use_artifact(wandb_artifact_path)
artifact_dir = artifact.download()
path_to_artifact = f"{artifact_dir}/{os.listdir(artifact_dir)[0]}"
Expand Down Expand Up @@ -132,7 +134,7 @@ def __init__(
if len(os.listdir(self.dashboard_folder)) > 0:
raise ValueError("Dashboard folder not empty. Aborting.")

def get_feature_sparsity_path(self, wandb_artifact_path):
def get_feature_sparsity_path(self, wandb_artifact_path: str):
prefix = wandb_artifact_path.split(":")[0]
return f"{prefix}_log_feature_sparsity:v9"

Expand All @@ -147,11 +149,14 @@ def get_dashboard_folder_name(self):
def init_sae_session(self):
(
self.model,
self.sparse_autoencoder,
sae_group,
self.activation_store,
) = LMSparseAutoencoderSessionloader.load_session_from_pretrained(self.sae_path)
self.sparse_autoencoder = sae_group.autoencoders[0]

def get_tokens(self, n_batches_to_sample_from=2**12, n_prompts_to_select=4096 * 6):
def get_tokens(
self, n_batches_to_sample_from: int = 2**12, n_prompts_to_select: int = 4096 * 6
):
"""
Get the tokens needed for dashboard generation.
"""
Expand All @@ -170,13 +175,16 @@ def get_tokens(self, n_batches_to_sample_from=2**12, n_prompts_to_select=4096 *
return all_tokens[:n_prompts_to_select]

def get_index_to_resume_from(self):
i = 0
for i in range(self.n_features):
if not os.path.exists(f"{self.dashboard_folder}/data_{i:04}.html"):
break

assert self.sparse_autoencoder.cfg.d_sae is not None # keep pyright happy
n_features = self.sparse_autoencoder.cfg.d_sae
n_features_at_a_time = self.n_features_at_a_time
id_of_last_feature_without_dashboard = i
assert self.final_index is not None # keep pyright happy
n_features_remaining = self.final_index - id_of_last_feature_without_dashboard
n_batches_to_do = n_features_remaining // n_features_at_a_time
if self.final_index == n_features:
Expand Down Expand Up @@ -208,8 +216,8 @@ def get_feature_property_df(self):
/ sparse_autoencoder.W_enc.cpu().norm(dim=-1, keepdim=True)
)
d_e_projection = cosine_similarity(W_dec_normalized, W_enc_normalized.T)
b_dec_projection = sparse_autoencoder.b_dec.cpu() @ W_dec_normalized.T

assert sparse_autoencoder.cfg.d_sae is not None # keep pyright happy
temp_df = pd.DataFrame(
{
"log_feature_sparsity": feature_sparsity + 1e-10,
Expand All @@ -231,13 +239,14 @@ def run(self):
Generate the dashboard.
"""

run = None
if self.use_wandb:
# get name from wandb
random_suffix = str(uuid.uuid4())[:8]
name = f"{self.get_dashboard_folder_name()}_{random_suffix}"
run = wandb.init(
project="feature_dashboards",
config=self.sparse_autoencoder.cfg,
config=cast(Any, self.sparse_autoencoder.cfg),
name=name,
tags=[
f"model_{self.sparse_autoencoder.cfg.model_name}",
Expand Down Expand Up @@ -294,6 +303,7 @@ def run(self):
)
wandb.log({"plots/scatter_matrix": wandb.Html(plotly.io.to_html(fig))})

assert self.sparse_autoencoder.cfg.d_sae is not None # keep pyright happy
self.n_features = self.sparse_autoencoder.cfg.d_sae
id_to_start_from = self.get_index_to_resume_from()
id_to_end_at = self.n_features if self.final_index is None else self.final_index
Expand Down Expand Up @@ -351,6 +361,7 @@ def run(self):
artifact.add_file(
f"{self.dashboard_folder}/data_{test_idx:04}.html"
)
assert run is not None # keep pyright happy
run.log_artifact(artifact)

# also upload as html to dashboard
Expand All @@ -369,6 +380,7 @@ def run(self):
# then upload the zip as an artifact
artifact = wandb.Artifact("dashboard", type="zipped_feature_dashboards")
artifact.add_file(f"{self.dashboard_folder}.zip")
assert run is not None # keep pyright happy
run.log_artifact(artifact)

# terminate the run
Expand Down
Loading

0 comments on commit bd5fc43

Please sign in to comment.