Skip to content
6 changes: 3 additions & 3 deletions sae_lens/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def get_sparsity_and_variance_metrics(
(flattened_sae_input).pow(2).sum(dim=-1).mean(dim=0) # scalar
)
mean_act_per_dimension.append(
(flattened_sae_input).pow(2).mean(dim=0) # [d_model]
(flattened_sae_input).mean(dim=0) # [d_model]
)
mean_sum_of_resid_squared.append(
resid_sum_of_squares.mean(dim=0) # scalar
Expand Down Expand Up @@ -582,8 +582,8 @@ def get_sparsity_and_variance_metrics(
# calculate explained variance
if compute_variance_metrics:
mean_sum_of_squares = torch.stack(mean_sum_of_squares).mean(dim=0)
mean_act_per_dimension = torch.cat(mean_act_per_dimension).mean(dim=0)
total_variance = mean_sum_of_squares - mean_act_per_dimension**2
mean_act_per_dimension = torch.stack(mean_act_per_dimension).mean(dim=0)
total_variance = mean_sum_of_squares - (mean_act_per_dimension**2).sum()
residual_variance = torch.stack(mean_sum_of_resid_squared).mean(dim=0)
metrics["explained_variance"] = (1 - residual_variance / total_variance).item()
Copy link

Copilot AI Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

total_variance can legitimately be 0 (e.g., if only one unmasked token remains after ignore_tokens, or if activations are constant). In that case 1 - residual_variance / total_variance will produce inf/NaN. Consider guarding for total_variance <= eps (and possibly small negative values from fp roundoff) similarly to sae_lens/synthetic/evals.py, returning 1.0 when both variances are ~0, else 0.0 (or another defined fallback).

Suggested change
metrics["explained_variance"] = (1 - residual_variance / total_variance).item()
# Guard against zero / near-zero total variance to avoid inf/NaN.
# When both variances are ~0, treat explained variance as 1.0
# (perfect reconstruction of a constant signal); otherwise 0.0.
eps = 1e-12
if torch.abs(total_variance) <= eps:
if torch.abs(residual_variance) <= eps:
explained_variance = torch.tensor(1.0, device=total_variance.device)
else:
explained_variance = torch.tensor(0.0, device=total_variance.device)
else:
explained_variance = 1 - residual_variance / total_variance
metrics["explained_variance"] = explained_variance.item()

Copilot uses AI. Check for mistakes.

Expand Down
32 changes: 32 additions & 0 deletions tests/test_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,38 @@ def test_get_sparsity_and_variance_metrics_identity_sae_perfect_reconstruction(
assert metrics["mse"] == pytest.approx(0.0, abs=1e-5)


def test_explained_variance_uses_mean_centered_variance():
    """Verify explained_variance computes Var(X) = E[||X||^2] - ||E[X]||^2, not E[||X||^2]."""
    # Seed the RNG so the statistical tolerances below are deterministic
    # instead of relying on luck across runs.
    torch.manual_seed(0)

    # Construct inputs with a large mean so the difference between
    # variance-from-zero and variance-from-mean is significant.
    d_model = 8
    n_samples = 1000
    mean = torch.full((d_model,), 10.0)
    x = mean + torch.randn(n_samples, d_model) * 0.5

    # Ground truth total variance: sum of per-dimension population variances.
    expected_total_var = x.var(dim=0, correction=0).sum().item()

    # The formula used in evals.py after the fix:
    # total_variance = E[||X||^2] - ||E[X]||^2
    mean_sum_of_squares = x.pow(2).sum(dim=-1).mean(dim=0)
    mean_act_per_dimension = x.mean(dim=0)
    computed_total_var = (
        mean_sum_of_squares - (mean_act_per_dimension**2).sum()
    ).item()

    # Relative-tolerance check (equivalent to pytest.approx(..., rel=1e-3)).
    assert abs(computed_total_var - expected_total_var) <= 1e-3 * abs(
        expected_total_var
    )

    # With the bug (.pow(2) before the mean), the subtracted term becomes
    # sum_d E[x_d^2]^2 instead of ||E[x]||^2.  For data with a large mean this
    # subtracts far too much, so the buggy total variance is wildly wrong
    # (typically very negative) and nowhere near the true variance.
    buggy_mean_act = x.pow(2).mean(dim=0)  # bug: .pow(2) before mean
    buggy_total_var = (
        mean_sum_of_squares - (buggy_mean_act**2).sum()
    ).item()
    assert abs(buggy_total_var - expected_total_var) > expected_total_var * 0.5


def test_process_args():
args = [
"gpt2-small-res_scefr-ajt",
Expand Down
Loading