
Methodology: Quantization × Interpretability

Overview

This document describes the technical methodology for studying how quantization affects SAE feature interpretability.


1. Activation Extraction

1.1 Model Loading

Models are loaded at different precisions using HuggingFace transformers with bitsandbytes:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# BF16 (baseline)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# INT8
config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=config,
    device_map="auto",
)

# INT4 (NF4)
config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# FP4
config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="fp4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Each 4-bit config is passed to from_pretrained via quantization_config=,
# exactly as in the INT8 example above.

1.2 Hook-based Extraction

Register forward hooks on transformer layers to capture hidden states:

from collections import defaultdict

layer_activations = defaultdict(list)
hooks = []

def make_hook(layer_idx):
    def hook_fn(module, input, output):
        # Some layers return a tuple; the hidden states are the first element.
        # Keep only the last-token activation, moved to CPU in float32.
        if isinstance(output, tuple):
            act = output[0][:, -1, :].detach().cpu().float()
        else:
            act = output[:, -1, :].detach().cpu().float()
        layer_activations[layer_idx].append(act)
    return hook_fn

# Register hooks on the selected layers
for layer_idx in probe_layers:
    layer = model.model.layers[layer_idx]
    hook = layer.register_forward_hook(make_hook(layer_idx))
    hooks.append(hook)

1.3 Layer Selection

For efficiency, we probe a subset of layers:

  • Early: Layer 0 (embedding proximity)
  • Evenly spaced: Every 4th layer
  • Final: Last 2 layers (output proximity)

Example for 48-layer model: [0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 46, 47]
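
The selection rule above can be sketched as a small helper (function name hypothetical):

```python
def select_probe_layers(n_layers: int, stride: int = 4, n_final: int = 2) -> list[int]:
    """Layer 0, every `stride`-th layer, and the last `n_final` layers (deduplicated)."""
    layers = set(range(0, n_layers, stride))
    layers.update(range(n_layers - n_final, n_layers))
    return sorted(layers)

select_probe_layers(48)
# [0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 46, 47]
```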

1.4 Prompt Design

Prompts span diverse code scenarios:

| Category | Examples | Count |
| --- | --- | --- |
| Python functions | `def process_`, `async def fetch_` | 230 |
| Type annotations | `items: List[`, `result: Dict[str,` | 100 |
| Class definitions | `class User(BaseModel):` | 50 |
| JavaScript/TS | `function handle`, `interface User {` | 320 |
| Rust | `fn process(`, `impl Iterator for` | 150 |
| Go | `func Process(`, `type Config struct {` | 120 |
| Error handling | `try:`, `except Exception as e:` | 96 |

2. SAE Architecture

2.1 Simple SAE

Standard sparse autoencoder with ReLU activation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleSAE(nn.Module):
    def __init__(self, d_input: int, d_hidden: int, l1_coef: float = 1e-3):
        super().__init__()
        self.encoder = nn.Linear(d_input, d_hidden)
        self.decoder = nn.Linear(d_hidden, d_input)
        self.l1_coef = l1_coef

    def forward(self, x):
        z = F.relu(self.encoder(x))
        x_hat = self.decoder(z)
        return x_hat, z

    def loss(self, x, x_hat, z):
        mse = F.mse_loss(x_hat, x)
        l1 = self.l1_coef * z.abs().mean()
        return mse + l1

2.2 TopK SAE

Enforces exact sparsity by keeping only top-k activations:

class TopKSAE(nn.Module):
    def __init__(self, d_input: int, d_hidden: int, k: int = 64):
        super().__init__()
        self.encoder = nn.Linear(d_input, d_hidden)
        self.decoder = nn.Linear(d_hidden, d_input)
        self.k = k

    def forward(self, x):
        pre_act = self.encoder(x)
        topk_vals, topk_idx = torch.topk(pre_act, self.k, dim=-1)
        z = torch.zeros_like(pre_act)
        z.scatter_(-1, topk_idx, F.relu(topk_vals))
        x_hat = self.decoder(z)
        return x_hat, z

2.3 Gated SAE

Uses learned gating for sparsity:

class GatedSAE(nn.Module):
    def __init__(self, d_input: int, d_hidden: int):
        super().__init__()
        self.W_gate = nn.Linear(d_input, d_hidden)
        self.W_mag = nn.Linear(d_input, d_hidden)
        self.decoder = nn.Linear(d_hidden, d_input)

    def forward(self, x):
        gate = torch.sigmoid(self.W_gate(x))
        mag = F.relu(self.W_mag(x))
        z = gate * mag
        x_hat = self.decoder(z)
        return x_hat, z

2.4 Training Configuration

| Parameter | Value | Rationale |
| --- | --- | --- |
| Hidden dim | 8192 | 2x input dim for overcomplete dictionary |
| L1 coefficient | 1e-3 to 5e-4 | Balance sparsity vs reconstruction |
| Optimizer | Adam | Standard choice |
| Learning rate | 1e-3 | With cosine annealing |
| Epochs | 5000-10000 | Early stopping on convergence |
| Batch size | Full dataset | Small enough to fit in memory |
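
A minimal full-batch training loop consistent with the table (a sketch: `train_sae` is a hypothetical name, and the SAE can be any of the modules above that return `(x_hat, z)`):

```python
import torch
import torch.nn.functional as F

def train_sae(sae, X, epochs=5000, lr=1e-3, l1_coef=1e-3):
    """Full-batch Adam with cosine annealing; loss = MSE + L1 on the codes."""
    opt = torch.optim.Adam(sae.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    for _ in range(epochs):
        opt.zero_grad()
        x_hat, z = sae(X)
        loss = F.mse_loss(x_hat, X) + l1_coef * z.abs().mean()
        loss.backward()
        opt.step()
        sched.step()
    return sae
```

For the TopK variant the L1 term is unnecessary (sparsity is exact by construction), so `l1_coef=0` would be passed.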

3. Alignment Metrics

3.1 Raw Cosine Alignment

Best-match cosine similarity between encoder weight vectors:

import numpy as np

def compute_raw_alignment(W1, W2):
    # Normalize
    W1_norm = W1 / np.linalg.norm(W1, axis=1, keepdims=True)
    W2_norm = W2 / np.linalg.norm(W2, axis=1, keepdims=True)

    # Similarity matrix
    sim_matrix = W1_norm @ W2_norm.T

    # Best match for each feature in W1
    best_matches = np.max(sim_matrix, axis=1)

    return np.mean(best_matches)

Interpretation:

  • 1.0: Identical features
  • 0.1: Random baseline (high-dimensional)
  • Observed: 0.14-0.42
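
A quick sanity check on the metric (the function is restated here so the snippet runs on its own): identical dictionaries score 1.0, while independent random ones land near the baseline.

```python
import numpy as np

def compute_raw_alignment(W1, W2):
    W1_norm = W1 / np.linalg.norm(W1, axis=1, keepdims=True)
    W2_norm = W2 / np.linalg.norm(W2, axis=1, keepdims=True)
    return np.mean(np.max(W1_norm @ W2_norm.T, axis=1))

rng = np.random.default_rng(42)
W = rng.standard_normal((512, 128))
print(compute_raw_alignment(W, W.copy()))                          # ≈ 1.0
print(compute_raw_alignment(W, rng.standard_normal((512, 128))))   # well below 1
```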

3.2 Procrustes Alignment

Find optimal rotation to align feature spaces:

from scipy.linalg import orthogonal_procrustes

def compute_procrustes_alignment(W1, W2):
    # Normalize
    W1_norm = W1 / np.linalg.norm(W1, axis=1, keepdims=True)
    W2_norm = W2 / np.linalg.norm(W2, axis=1, keepdims=True)

    # Find rotation matrix
    R, _ = orthogonal_procrustes(W1_norm, W2_norm)

    # Apply rotation
    W1_aligned = W1_norm @ R

    # Compute alignment
    sim_matrix = W1_aligned @ W2_norm.T
    best_matches = np.max(sim_matrix, axis=1)

    return np.mean(best_matches), R

Interpretation:

  • Higher than raw alignment indicates rotated but structurally similar spaces
  • Observed: 0.43 (vs 0.21 raw)

3.3 Mutual Nearest Neighbors

Stricter metric requiring bidirectional matches:

def compute_mutual_nn(W1, W2):
    # Normalize rows
    W1_norm = W1 / np.linalg.norm(W1, axis=1, keepdims=True)
    W2_norm = W2 / np.linalg.norm(W2, axis=1, keepdims=True)

    sim_matrix = W1_norm @ W2_norm.T

    # Nearest neighbor in each direction
    nn_1to2 = np.argmax(sim_matrix, axis=1)
    nn_2to1 = np.argmax(sim_matrix, axis=0)

    # Count pairs that are each other's nearest neighbor
    mutual_nn = sum(1 for i, j in enumerate(nn_1to2) if nn_2to1[j] == i)

    return mutual_nn / len(nn_1to2)
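
As a degenerate-case check (self-contained sketch, with the row normalization written out): every feature of a dictionary is its own mutual nearest neighbor, so comparing a dictionary with itself gives 1.0.

```python
import numpy as np

def compute_mutual_nn(W1, W2):
    W1n = W1 / np.linalg.norm(W1, axis=1, keepdims=True)
    W2n = W2 / np.linalg.norm(W2, axis=1, keepdims=True)
    sim = W1n @ W2n.T
    nn_1to2 = np.argmax(sim, axis=1)
    nn_2to1 = np.argmax(sim, axis=0)
    mutual = sum(1 for i, j in enumerate(nn_1to2) if nn_2to1[j] == i)
    return mutual / len(nn_1to2)

W = np.random.default_rng(0).standard_normal((256, 64))
print(compute_mutual_nn(W, W))  # 1.0
```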

3.4 CKA (Centered Kernel Alignment)

Rotation-invariant similarity of representations:

def compute_cka(X1, X2):
    # Center
    X1 = X1 - X1.mean(axis=0)
    X2 = X2 - X2.mean(axis=0)

    # Gram matrices
    K1 = X1 @ X1.T
    K2 = X2 @ X2.T

    # HSIC
    hsic_12 = np.sum(K1 * K2)
    hsic_11 = np.sum(K1 * K1)
    hsic_22 = np.sum(K2 * K2)

    return hsic_12 / np.sqrt(hsic_11 * hsic_22)

Interpretation:

  • 1.0: Identical representations (up to rotation)
  • 0.0: Orthogonal representations
  • Observed: 0.89 (high similarity)
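
The rotation invariance can be checked directly (self-contained sketch): rotating the features by a random orthogonal matrix leaves linear CKA at 1.0, since the Gram matrix K = XXᵀ is unchanged by X → XQ.

```python
import numpy as np

def compute_cka(X1, X2):
    X1 = X1 - X1.mean(axis=0)
    X2 = X2 - X2.mean(axis=0)
    K1, K2 = X1 @ X1.T, X2 @ X2.T
    return np.sum(K1 * K2) / np.sqrt(np.sum(K1 * K1) * np.sum(K2 * K2))

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 32))
Q, _ = np.linalg.qr(rng.standard_normal((32, 32)))  # random orthogonal matrix
print(compute_cka(X, X @ Q))  # ≈ 1.0
```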

4. Dead Neuron Analysis

4.1 Definition

A neuron is "dead" if its maximum activation across all samples is below threshold:

def count_dead_neurons(z, threshold=1e-6):
    max_activations = z.max(dim=0).values
    return (max_activations < threshold).sum().item()

4.2 Neuron Categories

| Category | Criterion | Interpretation |
| --- | --- | --- |
| Dead | max_act < 1e-6 | Never activates |
| Rare | freq < 1% | Highly selective |
| Common | 1% ≤ freq < 50% | Moderately active |
| Frequent | freq ≥ 50% | General feature |
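
The four buckets can be computed in one pass over the code matrix (helper name hypothetical; activation frequency is measured against the dead threshold):

```python
import torch

def categorize_neurons(z, dead_thresh=1e-6, rare_frac=0.01, freq_frac=0.50):
    """Bucket hidden units by activation frequency. z: (n_samples, d_hidden)."""
    max_act = z.max(dim=0).values
    freq = (z > dead_thresh).float().mean(dim=0)
    dead = max_act < dead_thresh
    rare = ~dead & (freq < rare_frac)
    frequent = freq >= freq_frac
    common = ~dead & ~rare & ~frequent
    return {"dead": dead, "rare": rare, "common": common, "frequent": frequent}
```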

4.3 Root Cause Analysis

Dead neurons arise from:

  1. Insufficient data: Small sample size doesn't activate all features
  2. L1 penalty: Pushes weights toward zero
  3. ReLU saturation: Negative pre-activations stay at zero
  4. Encoder-decoder imbalance: Some directions not useful for reconstruction

5. Statistical Analysis

5.1 Bootstrap Confidence Intervals

def bootstrap_ci(data, n_bootstrap=1000, ci=0.95):
    boot_means = []
    for _ in range(n_bootstrap):
        sample = np.random.choice(data, size=len(data), replace=True)
        boot_means.append(np.mean(sample))

    alpha = 1 - ci
    lower = np.percentile(boot_means, 100 * alpha / 2)
    upper = np.percentile(boot_means, 100 * (1 - alpha / 2))

    return {"mean": np.mean(data), "ci_lower": lower, "ci_upper": upper}

5.2 Permutation Tests

def paired_permutation_test(x, y, n_permutations=10000):
    observed_diff = np.mean(x) - np.mean(y)

    perm_diffs = []
    combined = np.column_stack([x, y])

    for _ in range(n_permutations):
        swaps = np.random.randint(0, 2, size=len(x))
        perm_x = np.where(swaps == 0, combined[:, 0], combined[:, 1])
        perm_y = np.where(swaps == 0, combined[:, 1], combined[:, 0])
        perm_diffs.append(np.mean(perm_x) - np.mean(perm_y))

    p_value = np.mean(np.abs(perm_diffs) >= np.abs(observed_diff))
    return p_value
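
A degenerate-case check (the function is restated so the snippet runs on its own, with an explicit seed added for reproducibility): when the two paired samples are identical, the observed difference is 0 and the test returns p = 1.0.

```python
import numpy as np

def paired_permutation_test(x, y, n_permutations=10000, seed=0):
    rng = np.random.default_rng(seed)
    observed_diff = np.mean(x) - np.mean(y)
    combined = np.column_stack([x, y])
    perm_diffs = []
    for _ in range(n_permutations):
        # Randomly swap members within each pair
        swaps = rng.integers(0, 2, size=len(x))
        perm_x = np.where(swaps == 0, combined[:, 0], combined[:, 1])
        perm_y = np.where(swaps == 0, combined[:, 1], combined[:, 0])
        perm_diffs.append(np.mean(perm_x) - np.mean(perm_y))
    return np.mean(np.abs(perm_diffs) >= np.abs(observed_diff))

x = np.arange(10.0)
print(paired_permutation_test(x, x.copy(), n_permutations=200))  # 1.0
```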

5.3 Null Distribution Simulation

For alignment significance testing:

def simulate_random_alignment(n_features=8192, d_input=4096, n_simulations=100):
    null_sims = []
    for _ in range(n_simulations):
        W1 = np.random.randn(n_features, d_input)
        W2 = np.random.randn(n_features, d_input)
        W1 = W1 / np.linalg.norm(W1, axis=1, keepdims=True)
        W2 = W2 / np.linalg.norm(W2, axis=1, keepdims=True)

        sim_matrix = W1 @ W2.T
        best_matches = np.max(sim_matrix, axis=1)
        null_sims.append(np.mean(best_matches))

    return np.mean(null_sims), np.std(null_sims)

6. Reproducibility

6.1 Environment

Python 3.11
PyTorch 2.1+
transformers 4.40+
bitsandbytes 0.43+
scipy 1.11+
numpy 1.24+

6.2 Hardware

  • GPU: NVIDIA H100 80GB HBM3
  • VRAM requirements:
    • BF16: ~60GB for 30B model
    • INT8: ~35GB
    • INT4: ~20GB

6.3 Random Seeds

import numpy as np
import torch

np.random.seed(42)
torch.manual_seed(42)

7. Limitations

  1. Sample size: 227-1446 prompts may not capture full feature diversity
  2. SAE architecture: Only tested 3 variants, others may behave differently
  3. Models tested: 2 models (MoE and dense), more needed for generality
  4. Quantization methods: Only tested bitsandbytes, not GPTQ/AWQ
  5. Downstream tasks: No task-specific evaluation yet

8. Future Improvements

  1. Larger prompt sets: 10K+ diverse prompts
  2. Expert-level analysis: For MoE, analyze per-expert features
  3. Activation patching: Test if aligned features have same causal effects
  4. More quantization methods: GPTQ, AWQ, SmoothQuant
  5. Downstream benchmarks: HumanEval, MMLU, etc.