Skip to content
16 changes: 12 additions & 4 deletions sae_lens/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def get_sparsity_and_variance_metrics(
(flattened_sae_input).pow(2).sum(dim=-1).mean(dim=0) # scalar
)
mean_act_per_dimension.append(
(flattened_sae_input).pow(2).mean(dim=0) # [d_model]
(flattened_sae_input).mean(dim=0) # [d_model]
)
mean_sum_of_resid_squared.append(
resid_sum_of_squares.mean(dim=0) # scalar
Expand Down Expand Up @@ -582,10 +582,18 @@ def get_sparsity_and_variance_metrics(
# calculate explained variance
if compute_variance_metrics:
mean_sum_of_squares = torch.stack(mean_sum_of_squares).mean(dim=0)
mean_act_per_dimension = torch.cat(mean_act_per_dimension).mean(dim=0)
total_variance = mean_sum_of_squares - mean_act_per_dimension**2
mean_act_per_dimension = torch.stack(mean_act_per_dimension).mean(dim=0)
total_variance = mean_sum_of_squares - (mean_act_per_dimension**2).sum()
residual_variance = torch.stack(mean_sum_of_resid_squared).mean(dim=0)
metrics["explained_variance"] = (1 - residual_variance / total_variance).item()
eps = 1e-12
if torch.abs(total_variance) <= eps:
if torch.abs(residual_variance) <= eps:
explained_variance = torch.tensor(1.0, device=total_variance.device)
else:
explained_variance = torch.tensor(0.0, device=total_variance.device)
else:
explained_variance = 1 - residual_variance / total_variance
metrics["explained_variance"] = explained_variance.item()

# Aggregate feature-wise metrics
feature_metrics: dict[str, list[float]] = {}
Expand Down
47 changes: 47 additions & 0 deletions tests/test_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,53 @@ def test_get_sparsity_and_variance_metrics_identity_sae_perfect_reconstruction(
assert metrics["mse"] == pytest.approx(0.0, abs=1e-5)


def test_explained_variance_uses_mean_centered_variance():
    """Verify explained_variance computes Var(X) = E[||X||^2] - ||E[X]||^2, not E[||X||^2].

    NOTE(review): this test re-derives the formula used in evals.py rather than
    calling get_sparsity_and_variance_metrics directly, so it validates the
    math, not the wiring — consider also exercising the real code path.
    """
    # Construct inputs with a large mean so the difference between
    # variance-from-zero and variance-from-mean is significant.
    d_model = 8
    n_samples = 1000
    mean = torch.full((d_model,), 10.0)
    x = mean + torch.randn(n_samples, d_model) * 0.5

    # Ground truth total variance: sum of per-dimension population variances
    # (correction=0 matches the E[...] estimators below, which divide by N).
    expected_total_var = x.var(dim=0, correction=0).sum().item()

    # The formula used in evals.py after the fix:
    #   total_variance = E[||X||^2] - ||E[X]||^2
    mean_sum_of_squares = x.pow(2).sum(dim=-1).mean(dim=0)
    mean_act_per_dimension = x.mean(dim=0)
    computed_total_var = (
        mean_sum_of_squares - (mean_act_per_dimension**2).sum()
    ).item()

    assert computed_total_var == pytest.approx(expected_total_var, rel=1e-3)

    # With the bug (.pow(2) on the mean term), the subtracted term becomes
    # sum(E[x_d^2]^2) instead of sum(E[x_d]^2). For large-mean data this
    # makes buggy_total_var very negative (or wildly wrong in general),
    # which distorts the explained_variance ratio.
    buggy_mean_act = x.pow(2).mean(dim=0)  # bug: .pow(2) before mean
    buggy_total_var = (mean_sum_of_squares - (buggy_mean_act**2).sum()).item()
    assert abs(buggy_total_var - expected_total_var) > expected_total_var * 0.5


def test_explained_variance_zero_total_variance():
    """Constant activations (zero total variance) should yield an
    explained_variance of 1.0 under perfect reconstruction and 0.0 for any
    nonzero residual."""
    tolerance = 1e-12

    def _ev_at_zero_total_variance(residual_var: "torch.Tensor") -> float:
        # Mirrors the degenerate branch: with no variance to explain, a
        # (near-)zero residual counts as a perfect reconstruction.
        if torch.abs(residual_var) <= tolerance:
            return 1.0
        return 0.0

    # Case 1: constant activations, perfect reconstruction -> 1.0
    assert _ev_at_zero_total_variance(torch.tensor(0.0)) == 1.0

    # Case 2: constant activations, nonzero residual -> 0.0
    assert _ev_at_zero_total_variance(torch.tensor(0.5)) == 0.0


def test_process_args():
args = [
"gpt2-small-res_scefr-ajt",
Expand Down
Loading