Fix explained_variance computing variance relative to zero instead of mean #665
```diff
@@ -644,6 +644,38 @@ def test_get_sparsity_and_variance_metrics_identity_sae_perfect_reconstruction(
     assert metrics["mse"] == pytest.approx(0.0, abs=1e-5)


+def test_explained_variance_uses_mean_centered_variance():
+    """Verify explained_variance computes Var(X) = E[||X||^2] - ||E[X]||^2, not E[||X||^2]."""
+    # Construct inputs with a large mean so the difference between
+    # variance-from-zero and variance-from-mean is significant.
+    d_model = 8
+    n_samples = 1000
+    mean = torch.full((d_model,), 10.0)
+    x = mean + torch.randn(n_samples, d_model) * 0.5
+
+    # Ground truth total variance: sum of per-dimension variances
+    expected_total_var = x.var(dim=0, correction=0).sum().item()
+
+    # The formula used in evals.py after the fix:
+    # total_variance = E[||X||^2] - ||E[X]||^2
+    mean_sum_of_squares = x.pow(2).sum(dim=-1).mean(dim=0)
+    mean_act_per_dimension = x.mean(dim=0)
+    computed_total_var = (
+        mean_sum_of_squares - (mean_act_per_dimension**2).sum()
+    ).item()
+
+    assert computed_total_var == pytest.approx(expected_total_var, rel=1e-3)
+
+    # With the bug (.pow(2) on the mean term), the subtracted term captures E[x^2]^2
+    # instead of E[x]^2, making total_variance much larger than the true variance
+    # for data with a large mean.
+    buggy_mean_act = x.pow(2).mean(dim=0)  # bug: .pow(2) before mean
+    buggy_total_var = (
+        mean_sum_of_squares - (buggy_mean_act**2).sum()
+    ).item()
+
+    assert abs(buggy_total_var - expected_total_var) > expected_total_var * 0.5
+
+
 def test_process_args():
     args = [
         "gpt2-small-res_scefr-ajt",
```
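The variance identity the test relies on can be checked standalone. The following is a minimal NumPy sketch of the same computation (the seed, shapes, and variable names here are illustrative, mirroring the torch test above): summing per-dimension population variances matches E[||X||^2] - ||E[X]||^2, while the buggy form that squares before averaging does not.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_samples = 8, 1000
# Data with a large mean (10.0) and small spread (0.5), as in the test.
x = 10.0 + rng.normal(size=(n_samples, d_model)) * 0.5

# Ground truth: sum of per-dimension population variances.
expected = x.var(axis=0, ddof=0).sum()

# Fixed formula: E[||x||^2] - ||E[x]||^2.
mean_sum_of_squares = (x ** 2).sum(axis=-1).mean()
fixed = mean_sum_of_squares - (x.mean(axis=0) ** 2).sum()

# Buggy formula: squares each activation before averaging, so the
# subtracted term is sum((E[x^2])^2) rather than ||E[x]||^2.
buggy = mean_sum_of_squares - ((x ** 2).mean(axis=0) ** 2).sum()
```

With a mean of 10, `expected` is about 2 while the buggy value is dominated by the squared second moment, so the two differ by orders of magnitude.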
`total_variance` can legitimately be 0 (e.g., if only one unmasked token remains after `ignore_tokens`, or if activations are constant). In that case `1 - residual_variance / total_variance` will produce inf/NaN. Consider guarding for `total_variance <= eps` (and possibly small negative values from fp roundoff) similarly to `sae_lens/synthetic/evals.py`, returning 1.0 when both variances are ~0, else 0.0 (or another defined fallback).
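The guard the reviewer suggests could look roughly like the sketch below. The function name, the `eps` default, and the 1.0/0.0 fallbacks are assumptions for illustration, not the actual `sae_lens` API:

```python
def explained_variance_with_guard(
    total_variance: float, residual_variance: float, eps: float = 1e-8
) -> float:
    """Sketch of a degenerate-case guard for explained variance.

    Hypothetical helper, not sae_lens code. When total_variance is ~0
    (e.g. a single unmasked token, or constant activations), the naive
    1 - residual/total would divide by zero.
    """
    if total_variance <= eps:
        # Both variances ~0: reconstruction is trivially perfect -> 1.0.
        # Residual left but no signal variance -> defined fallback of 0.0.
        return 1.0 if residual_variance <= eps else 0.0
    return 1.0 - residual_variance / total_variance
```

The `<= eps` comparison also absorbs small negative totals from floating-point roundoff in the E[||X||^2] - ||E[X]||^2 subtraction.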