This document describes the technical methodology for studying how quantization affects SAE feature interpretability.
Models are loaded at different precisions using HuggingFace transformers with bitsandbytes:
# BF16 (baseline)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
)
# INT8
config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=config,
device_map="auto",
)
# INT4 (NF4)
config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
# FP4
config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="fp4",
bnb_4bit_compute_dtype=torch.bfloat16,
)Register forward hooks on transformer layers to capture hidden states:
def make_hook(layer_idx):
    """Build a forward hook that records the last-token activation for `layer_idx`.

    The hook appends a detached, CPU, float32 copy of the final token position's
    hidden state into the (externally defined) `layer_activations[layer_idx]` list.
    """
    def hook_fn(module, input, output):
        # Some layers return (hidden_states, ...) tuples; unwrap when needed.
        hidden = output[0] if isinstance(output, tuple) else output
        # Keep only the final token, moved off-GPU and upcast to float32.
        layer_activations[layer_idx].append(hidden[:, -1, :].detach().cpu().float())
    return hook_fn
# Register hooks
for layer_idx in probe_layers:
layer = model.model.layers[layer_idx]
hook = layer.register_forward_hook(make_hook(layer_idx))
hooks.append(hook)For efficiency, we probe a subset of layers:
- Early: Layer 0 (embedding proximity)
- Evenly spaced: Every 4th layer
- Final: Last 2 layers (output proximity)
Example for 48-layer model: [0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 46, 47]
Prompts span diverse code scenarios:
| Category | Examples | Count |
|---|---|---|
| Python functions | def process_, async def fetch_ | 230 |
| Type annotations | items: List[, result: Dict[str, | 100 |
| Class definitions | class User(BaseModel): | 50 |
| JavaScript/TS | function handle, interface User { | 320 |
| Rust | fn process(, impl Iterator for | 150 |
| Go | func Process(, type Config struct { | 120 |
| Error handling | try:, except Exception as e: | 96 |
Standard sparse autoencoder with ReLU activation:
class SimpleSAE(nn.Module):
def __init__(self, d_input: int, d_hidden: int, l1_coef: float = 1e-3):
super().__init__()
self.encoder = nn.Linear(d_input, d_hidden)
self.decoder = nn.Linear(d_hidden, d_input)
self.l1_coef = l1_coef
def forward(self, x):
z = F.relu(self.encoder(x))
x_hat = self.decoder(z)
return x_hat, z
def loss(self, x, x_hat, z):
mse = F.mse_loss(x_hat, x)
l1 = self.l1_coef * z.abs().mean()
return mse + l1Enforces exact sparsity by keeping only top-k activations:
class TopKSAE(nn.Module):
def __init__(self, d_input: int, d_hidden: int, k: int = 64):
super().__init__()
self.encoder = nn.Linear(d_input, d_hidden)
self.decoder = nn.Linear(d_hidden, d_input)
self.k = k
def forward(self, x):
pre_act = self.encoder(x)
topk_vals, topk_idx = torch.topk(pre_act, self.k, dim=-1)
z = torch.zeros_like(pre_act)
z.scatter_(-1, topk_idx, F.relu(topk_vals))
x_hat = self.decoder(z)
return x_hat, zUses learned gating for sparsity:
class GatedSAE(nn.Module):
def __init__(self, d_input: int, d_hidden: int):
super().__init__()
self.W_gate = nn.Linear(d_input, d_hidden)
self.W_mag = nn.Linear(d_input, d_hidden)
self.decoder = nn.Linear(d_hidden, d_input)
def forward(self, x):
gate = torch.sigmoid(self.W_gate(x))
mag = F.relu(self.W_mag(x))
z = gate * mag
x_hat = self.decoder(z)
return x_hat, z| Parameter | Value | Rationale |
|---|---|---|
| Hidden dim | 8192 | 2x input dim for overcomplete dictionary |
| L1 coefficient | 1e-3 to 5e-4 | Balance sparsity vs reconstruction |
| Optimizer | Adam | Standard choice |
| Learning rate | 1e-3 | With cosine annealing |
| Epochs | 5000-10000 | Early stopping on convergence |
| Batch size | Full dataset | Small enough to fit in memory |
Best-match cosine similarity between encoder weight vectors:
def compute_raw_alignment(W1, W2):
# Normalize
W1_norm = W1 / np.linalg.norm(W1, axis=1, keepdims=True)
W2_norm = W2 / np.linalg.norm(W2, axis=1, keepdims=True)
# Similarity matrix
sim_matrix = W1_norm @ W2_norm.T
# Best match for each feature in W1
best_matches = np.max(sim_matrix, axis=1)
return np.mean(best_matches)Interpretation:
- 1.0: Identical features
- 0.1: Random baseline (high-dimensional)
- Observed: 0.14-0.42
Find optimal rotation to align feature spaces:
from scipy.linalg import orthogonal_procrustes
def compute_procrustes_alignment(W1, W2):
# Normalize
W1_norm = W1 / np.linalg.norm(W1, axis=1, keepdims=True)
W2_norm = W2 / np.linalg.norm(W2, axis=1, keepdims=True)
# Find rotation matrix
R, _ = orthogonal_procrustes(W1_norm, W2_norm)
# Apply rotation
W1_aligned = W1_norm @ R
# Compute alignment
sim_matrix = W1_aligned @ W2_norm.T
best_matches = np.max(sim_matrix, axis=1)
return np.mean(best_matches), RInterpretation:
- Higher than raw alignment indicates rotated but structurally similar spaces
- Observed: 0.43 (vs 0.21 raw)
Stricter metric requiring bidirectional matches:
def compute_mutual_nn(W1, W2):
sim_matrix = W1_norm @ W2_norm.T
nn_1to2 = np.argmax(sim_matrix, axis=1)
nn_2to1 = np.argmax(sim_matrix, axis=0)
mutual_nn = sum(1 for i, j in enumerate(nn_1to2) if nn_2to1[j] == i)
return mutual_nn / len(nn_1to2)Rotation-invariant similarity of representations:
def compute_cka(X1, X2):
# Center
X1 = X1 - X1.mean(axis=0)
X2 = X2 - X2.mean(axis=0)
# Gram matrices
K1 = X1 @ X1.T
K2 = X2 @ X2.T
# HSIC
hsic_12 = np.sum(K1 * K2)
hsic_11 = np.sum(K1 * K1)
hsic_22 = np.sum(K2 * K2)
return hsic_12 / np.sqrt(hsic_11 * hsic_22)Interpretation:
- 1.0: Identical representations (up to rotation)
- 0.0: Orthogonal representations
- Observed: 0.89 (high similarity)
A neuron is "dead" if its maximum activation across all samples is below threshold:
def count_dead_neurons(z, threshold=1e-6):
max_activations = z.max(dim=0).values
return (max_activations < threshold).sum().item()| Category | Criterion | Interpretation |
|---|---|---|
| Dead | max_act < 1e-6 | Never activates |
| Rare | freq < 1% | Highly selective |
| Common | 1% ≤ freq < 50% | Moderately active |
| Frequent | freq ≥ 50% | General feature |
Dead neurons arise from:
- Insufficient data: Small sample size doesn't activate all features
- L1 penalty: Pushes weights toward zero
- ReLU saturation: Negative pre-activations stay at zero
- Encoder-decoder imbalance: Some directions not useful for reconstruction
def bootstrap_ci(data, n_bootstrap=1000, ci=0.95):
boot_means = []
for _ in range(n_bootstrap):
sample = np.random.choice(data, size=len(data), replace=True)
boot_means.append(np.mean(sample))
alpha = 1 - ci
lower = np.percentile(boot_means, 100 * alpha / 2)
upper = np.percentile(boot_means, 100 * (1 - alpha / 2))
return {"mean": np.mean(data), "ci_lower": lower, "ci_upper": upper}def paired_permutation_test(x, y, n_permutations=10000):
observed_diff = np.mean(x) - np.mean(y)
perm_diffs = []
combined = np.column_stack([x, y])
for _ in range(n_permutations):
swaps = np.random.randint(0, 2, size=len(x))
perm_x = np.where(swaps == 0, combined[:, 0], combined[:, 1])
perm_y = np.where(swaps == 0, combined[:, 1], combined[:, 0])
perm_diffs.append(np.mean(perm_x) - np.mean(perm_y))
p_value = np.mean(np.abs(perm_diffs) >= np.abs(observed_diff))
return p_valueFor alignment significance testing:
def simulate_random_alignment(n_features=8192, d_input=4096, n_simulations=100):
null_sims = []
for _ in range(n_simulations):
W1 = np.random.randn(n_features, d_input)
W2 = np.random.randn(n_features, d_input)
W1 = W1 / np.linalg.norm(W1, axis=1, keepdims=True)
W2 = W2 / np.linalg.norm(W2, axis=1, keepdims=True)
sim_matrix = W1 @ W2.T
best_matches = np.max(sim_matrix, axis=1)
null_sims.append(np.mean(best_matches))
return np.mean(null_sims), np.std(null_sims)Python 3.11
PyTorch 2.1+
transformers 4.40+
bitsandbytes 0.43+
scipy 1.11+
numpy 1.24+
- GPU: NVIDIA H100 80GB HBM3
- VRAM requirements:
- BF16: ~60GB for 30B model
- INT8: ~35GB
- INT4: ~20GB
np.random.seed(42)
torch.manual_seed(42)- Sample size: 227-1446 prompts may not capture full feature diversity
- SAE architecture: Only tested 3 variants, others may behave differently
- Models tested: 2 models (MoE and dense), more needed for generality
- Quantization methods: Only tested bitsandbytes, not GPTQ/AWQ
- Downstream tasks: No task-specific evaluation yet
- Larger prompt sets: 10K+ diverse prompts
- Expert-level analysis: For MoE, analyze per-expert features
- Activation patching: Test if aligned features have same causal effects
- More quantization methods: GPTQ, AWQ, SmoothQuant
- Downstream benchmarks: HumanEval, MMLU, etc.