chore: Replace isort, black, and flake8 with ruff (#393)
* replaces in cache_activations_runner.py

* replaces isort, black, and flake8 with Ruff

* adds SIM lint rule

* fixes for CI check

* adds RET lint rule

* adds LOG lint rule

* fixes RET error

* resolves conflicts

* applies make format

* adds T20 rule

* replaces extend-select with select

* resolves conflicts

* fixes lint errors

* update .vscode/settings.json

* Revert "update .vscode/settings.json"

This reverts commit 1bb5497.

* updates .vscode/settings.json

* adds newline
anthonyduong9 authored Dec 3, 2024
1 parent d027158 commit 52dbff9
Showing 54 changed files with 273 additions and 383 deletions.
10 changes: 4 additions & 6 deletions .github/workflows/build.yml
@@ -60,12 +60,10 @@ jobs:
- name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --no-interaction
- name: Lint with flake8
run: poetry run flake8 .
- name: black code formatting
run: poetry run black . --check
- name: isort linting
run: poetry run isort . --check-only --diff
- name: Check linting
run: poetry run ruff check .
- name: Check formatting
run: poetry run ruff format --check .
- name: type checking
run: poetry run pyright
- name: Run Unit Tests
3 changes: 3 additions & 0 deletions .gitignore
@@ -143,6 +143,9 @@ venv.bak/
.dmypy.json
dmypy.json

# ruff
.ruff_cache

# Pyre type checker
.pyre/

16 changes: 4 additions & 12 deletions .pre-commit-config.yaml
@@ -6,16 +6,8 @@ repos:
- id: check-yaml
- id: check-added-large-files
args: [--maxkb=250000]
- repo: https://github.com/psf/black
rev: 24.3.0
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.4
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
name: isort (python)
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
hooks:
- id: flake8
- id: ruff
args: ["--fix"]
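
With the hook swap above, the same Ruff checks can be run locally through pre-commit — a minimal sketch, assuming pre-commit is installed from this repo's dev dependencies and using the hook id added above:

pre-commit install               # register the git hook once
pre-commit run ruff --all-files  # run the Ruff hook (with --fix) against every file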
8 changes: 5 additions & 3 deletions .vscode/settings.json
@@ -6,13 +6,15 @@
"python.testing.pytestEnabled": true,

"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.fixAll": "explicit",
"source.organizeImports": "explicit"
}
},
"isort.args": ["--profile", "black"],
"editor.defaultFormatter": "mikoz.black-py",

"notebook.formatOnSave.enabled": true,
"notebook.codeActionsOnSave": {},
"liveServer.settings.port": 5501
}
2 changes: 1 addition & 1 deletion docs/contributing.md
@@ -15,7 +15,7 @@ make check-ci # validate the install

## Testing, Linting, and Formatting

This project uses [pytest](https://docs.pytest.org/en/stable/) for testing, [flake8](https://flake8.pycqa.org/en/latest/) for linting, [pyright](https://github.com/microsoft/pyright) for type-checking, and [black](https://black.readthedocs.io/en/stable/) and [isort](https://pycqa.github.io/isort/) for formatting.
This project uses [pytest](https://docs.pytest.org/en/stable/) for testing, [pyright](https://github.com/microsoft/pyright) for type-checking, and [Ruff](https://docs.astral.sh/ruff/) for formatting and linting.

If you add new code, it would be greatly appreciated if you could add tests in the `tests/unit` directory. You can run the tests with:

7 changes: 4 additions & 3 deletions docs/generate_sae_table.py
@@ -1,4 +1,5 @@
# type: ignore
# ruff: noqa: T201
from pathlib import Path
from textwrap import dedent

@@ -26,7 +27,7 @@
]


def on_pre_build(config):
def on_pre_build(config): # noqa: ARG001
print("Generating SAE table...")
generate_sae_table()
print("SAE table generation complete.")
@@ -35,7 +36,7 @@ def on_pre_build(config):
def generate_sae_table():
# Read the YAML file
yaml_path = Path("sae_lens/pretrained_saes.yaml")
with open(yaml_path, "r") as file:
with open(yaml_path) as file:
data = yaml.safe_load(file)

# Start the Markdown content
@@ -68,7 +69,7 @@ def generate_sae_table():
cfg = handle_config_defaulting(cfg)
cfg = SAEConfig.from_dict(cfg).to_dict()

if "neuronpedia" not in info.keys():
if "neuronpedia" not in info:
info["neuronpedia"] = None

info.update(cfg)
9 changes: 4 additions & 5 deletions makefile
@@ -1,11 +1,10 @@
format:
poetry run black .
poetry run isort .
poetry run ruff format .
poetry run ruff check --fix-only .

check-format:
poetry run flake8 .
poetry run black --check .
poetry run isort --check-only --diff .
poetry run ruff check .
poetry run ruff format --check .

check-type:
poetry run pyright .
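
For reference, the updated make targets above are thin wrappers around Ruff; a minimal sketch of the equivalent direct commands, assuming the Poetry environment used throughout this repo:

poetry run ruff format .            # rewrite files with the Ruff formatter (replaces black)
poetry run ruff check --fix-only .  # apply autofixes such as import sorting (replaces isort)
poetry run ruff check .             # lint only, no file changes (replaces flake8)
poetry run ruff format --check .    # fail if any file would be reformatted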
16 changes: 8 additions & 8 deletions pyproject.toml
@@ -40,14 +40,10 @@ zstandard = "^0.22.0"
typing-extensions = "^4.10.0"
simple-parsing = "^0.1.6"


[tool.poetry.group.dev.dependencies]
black = { version = "24.4.0", extras = ["jupyter"] }
pytest = "^8.0.2"
pytest-cov = "^4.1.0"
pre-commit = "^3.6.2"
flake8 = "7.0.0"
isort = "5.13.2"
pyright = "1.1.365"
mamba-lens = "^0.0.4"
ansible-lint = { version = "^24.2.3", markers = "platform_system != 'Windows'" }
@@ -61,14 +57,19 @@ mkdocs-section-index = "^0.3.9"
mkdocstrings = "^0.25.2"
mkdocstrings-python = "^1.10.9"
tabulate = "^0.9.0"
ruff = "^0.7.4"

[tool.poetry.extras]
mamba = ["mamba-lens"]

[tool.ruff.lint]
exclude = ["*.ipynb"]
ignore = ["E203", "E501", "E731", "F722", "E741", "F821", "F403", "ARG002"]
select = ["UP", "TID", "I", "F", "E", "ARG", "SIM", "RET", "LOG", "T20"]

[tool.isort]
profile = "black"
src_paths = ["sae_lens", "tests"]
[tool.ruff.lint.per-file-ignores]
"tests/benchmark/*" = ["T20"]
"scripts/*" = ["T20"]

[tool.pyright]
typeCheckingMode = "strict"
@@ -92,7 +93,6 @@ ignore = [
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"


[tool.semantic_release]
version_variables = [
"sae_lens/__init__.py:__version__",
1 change: 1 addition & 0 deletions sae_lens/__init__.py
@@ -1,3 +1,4 @@
# ruff: noqa: E402
__version__ = "5.1.0"

import logging
11 changes: 3 additions & 8 deletions sae_lens/analysis/feature_statistics.py
@@ -20,7 +20,7 @@ def get_feature_property_df(sae: SAE, feature_sparsity: torch.Tensor):
d_e_projection = (W_dec_normalized * W_enc_normalized).sum(-1)
b_dec_projection = sae.b_dec.cpu() @ W_dec_normalized.T

temp_df = pd.DataFrame(
return pd.DataFrame(
{
"log_feature_sparsity": feature_sparsity + 1e-10,
"d_e_projection": d_e_projection,
@@ -32,8 +32,6 @@ def get_feature_property_df(sae: SAE, feature_sparsity: torch.Tensor):
}
)

return temp_df


@torch.no_grad()
def get_stats_df(projection: torch.Tensor):
@@ -48,7 +46,7 @@ def get_stats_df(projection: torch.Tensor):
skews = torch.mean(torch.pow(zscores, 3.0), dim=1)
kurtosis = torch.mean(torch.pow(zscores, 4.0), dim=1)

stats_df = pd.DataFrame(
return pd.DataFrame(
{
"feature": range(len(skews)),
"mean": mean.numpy().squeeze(),
@@ -58,8 +56,6 @@
}
)

return stats_df


@torch.no_grad()
def get_all_stats_dfs(
@@ -82,8 +78,7 @@
W_U_stats_df_dec["layer"] = layer + (1 if "post" in key else 0)
stats_dfs.append(W_U_stats_df_dec)

W_U_stats_df_dec_all_layers = pd.concat(stats_dfs, axis=0)
return W_U_stats_df_dec_all_layers
return pd.concat(stats_dfs, axis=0)


@torch.no_grad()
11 changes: 2 additions & 9 deletions sae_lens/analysis/hooked_sae_transformer.py
@@ -29,10 +29,7 @@ def get_deep_attr(obj: Any, path: str):
parts = path.split(".")
# Navigate to the last component in the path
for part in parts:
if part.isdigit(): # This is a list index
obj = obj[int(part)]
else: # This is an attribute
obj = getattr(obj, part)
obj = obj[int(part)] if part.isdigit() else getattr(obj, part)
return obj


@@ -48,16 +45,12 @@ def set_deep_attr(obj: Any, path: str, value: Any):
parts = path.split(".")
# Navigate to the last component in the path
for part in parts[:-1]:
if part.isdigit(): # This is a list index
obj = obj[int(part)]
else: # This is an attribute
obj = getattr(obj, part)
obj = obj[int(part)] if part.isdigit() else getattr(obj, part)
# Set the value on the final attribute
setattr(obj, parts[-1], value)


class HookedSAETransformer(HookedTransformer):

def __init__(
self,
*model_args: Any,
19 changes: 7 additions & 12 deletions sae_lens/analysis/neuronpedia_integration.py
@@ -55,8 +55,7 @@ def NanAndInfReplacer(value: str):
if value in replacements:
replaced_value = replacements[value]
return float(replaced_value)
else:
return NAN_REPLACEMENT
return NAN_REPLACEMENT


def open_neuronpedia_feature_dashboard(sae: SAE, index: int):
@@ -75,7 +74,6 @@ def get_neuronpedia_quick_list(
features: list[int],
name: str = "temporary_list",
):

sae_id = sae.cfg.neuronpedia_id
if sae_id is None:
logger.warning(
@@ -146,10 +144,7 @@ def has_activating_text(self) -> bool:
"""Check if the feature has activating text."""
if self.activations is None:
return False
else:
return any(
max(activation.act_values) > 0 for activation in self.activations
)
return any(max(activation.act_values) > 0 for activation in self.activations)


T = TypeVar("T")
@@ -209,8 +204,7 @@ def make_neuronpedia_list_with_features(
if "url" in result and open_browser:
webbrowser.open(result["url"])
return result["url"]
else:
raise Exception("Error in creating list: " + result["message"])
raise Exception("Error in creating list: " + result["message"])


def test_key(api_key: str):
@@ -265,8 +259,7 @@ async def autointerp_neuronpedia_features(  # noqa: C901
raise Exception(
"You need to provide an OpenAI API key either in environment variable OPENAI_API_KEY or as an argument."
)
else:
os.environ["OPENAI_API_KEY"] = openai_api_key
os.environ["OPENAI_API_KEY"] = openai_api_key

if autointerp_explainer_model_name not in HARMONY_V4_MODELS:
raise Exception(
@@ -452,7 +445,9 @@ async def autointerp_neuronpedia_features(  # noqa: C901
if do_score and autointerp_scorer_model_name and scored_simulation:
feature_data["activations"] = feature.activations
feature_data["simulationModel"] = autointerp_scorer_model_name
feature_data["simulationActivations"] = scored_simulation.scored_sequence_simulations # type: ignore
feature_data["simulationActivations"] = (
scored_simulation.scored_sequence_simulations
) # type: ignore
feature_data["simulationScore"] = feature.autointerp_explanation_score
feature_data_str = json.dumps(feature_data, default=vars)
