Merge pull request #22 from themachinefan/faster_geometric_median
Faster geometric median.
jbloomAus authored Mar 19, 2024
2 parents eb90cc9 + 736bf83 commit 341c49a
Showing 4 changed files with 136 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -7,7 +7,7 @@ repos:
      - id: check-added-large-files
        args: [--maxkb=250000]
  - repo: https://github.com/psf/black
-    rev: 24.2.0
+    rev: 24.3.0
    hooks:
      - id: black
  - repo: https://github.com/PyCQA/flake8
2 changes: 2 additions & 0 deletions makefile
@@ -7,6 +7,8 @@ check-format:
	poetry run black --check .
	poetry run isort --check-only --diff .

+check-type:
+	poetry run pyright .

test:
	make unit-test
126 changes: 126 additions & 0 deletions sae_training/geometric_median.py
@@ -0,0 +1,126 @@
from types import SimpleNamespace
from typing import Optional

import torch
import tqdm


def weighted_average(points: torch.Tensor, weights: torch.Tensor):
    weights = weights / weights.sum()
    return (points * weights.view(-1, 1)).sum(dim=0)


@torch.no_grad()
def geometric_median_objective(
    median: torch.Tensor, points: torch.Tensor, weights: torch.Tensor
) -> torch.Tensor:

    norms = torch.linalg.norm(points - median.view(1, -1), dim=1)

    return (norms * weights).sum()


def compute_geometric_median(
    points: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
    maxiter: int = 100,
    ftol: float = 1e-20,
    do_log: bool = False,
):
    """
    :param points: ``torch.Tensor`` of shape ``(n, d)``
    :param weights: Optional ``torch.Tensor`` of shape ``(n,)``.
    :param eps: Smallest allowed value of the denominator, to avoid divide by zero.
        Equivalently, this is a smoothing parameter. Default 1e-6.
    :param maxiter: Maximum number of Weiszfeld iterations. Default 100.
    :param ftol: If the objective value does not improve by at least this fraction, terminate the algorithm. Default 1e-20.
    :param do_log: If true, return a log of the objective values encountered over the course of the algorithm.
    :return: SimpleNamespace object with fields
        - `median`: estimate of the geometric median, a ``torch.Tensor`` of shape ``(d,)``.
        - `new_weights`: the final per-point Weiszfeld weights, a ``torch.Tensor`` of shape ``(n,)``.
        - `termination`: string explaining how the algorithm terminated.
        - `logs`: list of objective values encountered over the course of the algorithm (None if do_log is false).
    """
    with torch.no_grad():

        if weights is None:
            weights = torch.ones((points.shape[0],), device=points.device)
        # initialize median estimate at mean
        new_weights = weights
        median = weighted_average(points, weights)
        objective_value = geometric_median_objective(median, points, weights)
        if do_log:
            logs = [objective_value]
        else:
            logs = None

        # Weiszfeld iterations
        early_termination = False
        pbar = tqdm.tqdm(range(maxiter))
        for _ in pbar:
            prev_obj_value = objective_value

            norms = torch.linalg.norm(points - median.view(1, -1), dim=1)
            # Weiszfeld reweighting: w_i / max(||x_i - median||, eps)
            new_weights = weights / torch.clamp(norms, min=eps)
            median = weighted_average(points, new_weights)
            objective_value = geometric_median_objective(median, points, weights)

            if logs is not None:
                logs.append(objective_value)
            if abs(prev_obj_value - objective_value) <= ftol * objective_value:
                early_termination = True
                break

            pbar.set_description(f"Objective value: {objective_value:.4f}")

    median = weighted_average(points, new_weights)  # allow autodiff to track it
    return SimpleNamespace(
        median=median,
        new_weights=new_weights,
        termination=(
            "function value converged within tolerance"
            if early_termination
            else "maximum iterations reached"
        ),
        logs=logs,
    )


if __name__ == "__main__":
    import time

    from sae_training.geom_median.src.geom_median.torch import (
        compute_geometric_median as original_compute_geometric_median,
    )

    TOLERANCE = 1e-2

    dim1 = 10000
    dim2 = 768
    device = "cuda" if torch.cuda.is_available() else "cpu"

    sample = (
        torch.randn((dim1, dim2), device=device) * 100
    )  # seems to be the order of magnitude of the actual use case
    weights = torch.randn((dim1,), device=device)

    torch.tensor(weights, device=device)

    tic = time.perf_counter()
    new = compute_geometric_median(sample, weights=weights, maxiter=100)
    print(f"new code takes {time.perf_counter()-tic} seconds!")
    tic = time.perf_counter()
    old = original_compute_geometric_median(
        sample, weights=weights, skip_typechecks=True, maxiter=100, per_component=False
    )
    print(f"old code takes {time.perf_counter()-tic} seconds!")

    print(f"max diff in median {torch.max(torch.abs(new.median - old.median))}")
    print(
        f"max diff in weights {torch.max(torch.abs(new.new_weights - old.new_weights))}"
    )

    assert torch.allclose(new.median, old.median, atol=TOLERANCE), "Median diverges!"
    assert torch.allclose(
        new.new_weights, old.new_weights, atol=TOLERANCE
    ), "Weights diverge!"
10 changes: 7 additions & 3 deletions sae_training/train_sae_on_language_model.py
@@ -8,7 +8,7 @@
import wandb
from sae_training.activations_store import ActivationsStore
from sae_training.evals import run_evals
-from sae_training.geom_median.src.geom_median.torch import compute_geometric_median
+from sae_training.geometric_median import compute_geometric_median
from sae_training.optim import get_scheduler
from sae_training.sae_group import SAEGroup

@@ -80,16 +80,20 @@ def train_sae_on_language_model(
    for sae in sae_group:
        hyperparams = sae.cfg
        sae_layer_id = all_layers.index(hyperparams.hook_point_layer)
-        layer_acts = activation_store.storage_buffer.detach().cpu()[:, sae_layer_id, :]
        if hyperparams.b_dec_init_method == "geometric_median":
+            layer_acts = activation_store.storage_buffer.detach()[:, sae_layer_id, :]
            # get geometric median of the activations if we're using those.
            if sae_layer_id not in geometric_medians:
                median = compute_geometric_median(
-                    layer_acts, skip_typechecks=True, maxiter=100, per_component=False
+                    layer_acts,
+                    maxiter=100,
                ).median
                geometric_medians[sae_layer_id].append(median)
            sae.initialize_b_dec_with_precalculated(geometric_medians[sae_layer_id])
        elif hyperparams.b_dec_init_method == "mean":
+            layer_acts = activation_store.storage_buffer.detach().cpu()[
+                :, sae_layer_id, :
+            ]
            sae.initialize_b_dec_with_mean(layer_acts)
        sae.train()

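Two things are worth noting about this hunk. First, the geometric-median branch now keeps layer_acts on its original device (the .cpu() call is dropped), so the Weiszfeld iterations can run on the GPU, while the mean branch still moves activations to the CPU. Second, initialize_b_dec_with_precalculated is defined on the SAE class and is not part of this diff; a plausible minimal sketch, assuming it receives the median vector for a layer and simply copies it into the decoder bias b_dec, might look like this:

import torch

def initialize_b_dec_with_precalculated(self, origin: torch.Tensor) -> None:
    # Hypothetical sketch only -- the real method lives on the SAE class and is
    # not shown in this diff. It presumably moves the precomputed geometric
    # median onto the decoder bias, matching the bias's dtype and device.
    self.b_dec.data = (
        origin.detach().to(dtype=self.b_dec.dtype, device=self.b_dec.device).clone()
    )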
