diff --git a/.gitignore b/.gitignore index 14ac5454..53068cb7 100644 --- a/.gitignore +++ b/.gitignore @@ -77,6 +77,8 @@ target/ # Jupyter Notebook .ipynb_checkpoints +demos/**/graphs +demos/**/graph_files # IPython profile_default/ diff --git a/circuit_tracer/attribution/attribute.py b/circuit_tracer/attribution/attribute.py index 63e7cbf8..409b2d18 100644 --- a/circuit_tracer/attribution/attribute.py +++ b/circuit_tracer/attribution/attribute.py @@ -2,6 +2,7 @@ Unified attribution interface that routes to the correct implementation based on the ReplacementModel backend. """ +from collections.abc import Sequence from typing import TYPE_CHECKING, Literal import torch @@ -9,6 +10,7 @@ from circuit_tracer.graph import Graph if TYPE_CHECKING: + from circuit_tracer.attribution.targets import TargetSpec from circuit_tracer.replacement_model.replacement_model_nnsight import NNSightReplacementModel from circuit_tracer.replacement_model.replacement_model_transformerlens import ( TransformerLensReplacementModel, @@ -19,6 +21,7 @@ def attribute( prompt: str | torch.Tensor | list[int], model: "NNSightReplacementModel | TransformerLensReplacementModel", *, + attribution_targets: "Sequence[str] | Sequence[TargetSpec] | torch.Tensor | None" = None, max_n_logits: int = 10, desired_logit_prob: float = 0.95, batch_size: int = 512, @@ -35,8 +38,16 @@ def attribute( Args: prompt: Text, token ids, or tensor - will be tokenized if str. model: Frozen ``ReplacementModel`` (either nnsight or transformerlens backend) - max_n_logits: Max number of logit nodes. - desired_logit_prob: Keep logits until cumulative prob >= this value. 
+ attribution_targets: Target specification in one of four formats: + - None: Auto-select salient logits based on probability threshold + - torch.Tensor: Tensor of token indices + - Sequence[str]: Token strings (tokenized, auto-computes probability + and unembed vector) + - Sequence[TargetSpec]: Fully specified custom targets (CustomTarget or + tuple[str, float, torch.Tensor]) with arbitrary unembed directions + max_n_logits: Max number of logit nodes (used when attribution_targets is None). + desired_logit_prob: Keep logits until cumulative prob >= this value + (used when attribution_targets is None). batch_size: How many source nodes to process per backward pass. max_feature_nodes: Max number of feature nodes to include in the graph. offload: Method for offloading model parameters to save memory. @@ -55,6 +66,7 @@ def attribute( return attribute_nnsight( prompt=prompt, model=model, # type: ignore[arg-type] + attribution_targets=attribution_targets, max_n_logits=max_n_logits, desired_logit_prob=desired_logit_prob, batch_size=batch_size, @@ -69,6 +81,7 @@ def attribute( return attribute_transformerlens( prompt=prompt, model=model, # type: ignore[arg-type] + attribution_targets=attribution_targets, max_n_logits=max_n_logits, desired_logit_prob=desired_logit_prob, batch_size=batch_size, diff --git a/circuit_tracer/attribution/attribute_nnsight.py b/circuit_tracer/attribution/attribute_nnsight.py index 9e67e650..ea8fdf7b 100644 --- a/circuit_tracer/attribution/attribute_nnsight.py +++ b/circuit_tracer/attribution/attribute_nnsight.py @@ -22,21 +22,27 @@ import logging import time -from typing import Literal +from collections.abc import Sequence +from typing import Literal, cast import torch from tqdm import tqdm +from circuit_tracer.attribution.targets import ( + AttributionTargets, + TargetSpec, + log_attribution_target_info, +) from circuit_tracer.graph import Graph, compute_partial_influences from circuit_tracer.replacement_model.replacement_model_nnsight import 
NNSightReplacementModel from circuit_tracer.utils.disk_offload import offload_modules -from circuit_tracer.utils.salient_logits import compute_salient_logits def attribute( prompt: str | torch.Tensor | list[int], model: NNSightReplacementModel, *, + attribution_targets: Sequence[str] | Sequence[TargetSpec] | torch.Tensor | None = None, max_n_logits: int = 10, desired_logit_prob: float = 0.95, batch_size: int = 512, @@ -45,13 +51,21 @@ def attribute( verbose: bool = False, update_interval: int = 4, ) -> Graph: - """Compute an attribution graph for *prompt*. + """Compute an attribution graph for *prompt* using NNSight backend. Args: prompt: Text, token ids, or tensor - will be tokenized if str. model: Frozen ``NNSightReplacementModel`` - max_n_logits: Max number of logit nodes. - desired_logit_prob: Keep logits until cumulative prob >= this value. + attribution_targets: Target specification in one of four formats: + - None: Auto-select salient logits based on probability threshold + - torch.Tensor: Tensor of token indices + - Sequence[str]: Token strings (tokenized, auto-computes probability + and unembed vector) + - Sequence[TargetSpec]: Fully specified custom targets (CustomTarget or tuple) + with arbitrary unembed directions + max_n_logits: Max number of logit nodes (used when attribution_targets is None). + desired_logit_prob: Keep logits until cumulative prob >= this value + (used when attribution_targets is None). batch_size: How many source nodes to process per backward pass. max_feature_nodes: Max number of feature nodes to include in the graph. offload: Method for offloading model parameters to save memory. 
@@ -81,6 +95,7 @@ def attribute( return _run_attribution( model=model, prompt=prompt, + attribution_targets=attribution_targets, max_n_logits=max_n_logits, desired_logit_prob=desired_logit_prob, batch_size=batch_size, @@ -102,6 +117,7 @@ def attribute( def _run_attribution( model: NNSightReplacementModel, prompt, + attribution_targets, max_n_logits: int, desired_logit_prob: float, batch_size: int, @@ -156,15 +172,17 @@ def _run_attribution( n_layers, n_pos, _ = activation_matrix.shape total_active_feats = activation_matrix._nnz() - logit_idx, logit_p, logit_vecs = compute_salient_logits( - ctx.logits[0, -1], - model.unembed_weight, # type: ignore + # Create AttributionTargets using NNSight's unembed_weight accessor + targets = AttributionTargets( + attribution_targets=attribution_targets, + logits=ctx.logits[0, -1], + unembed_proj=cast(torch.Tensor, model.unembed_weight), # NNSight uses unembed_weight + tokenizer=model.tokenizer, max_n_logits=max_n_logits, desired_logit_prob=desired_logit_prob, ) - logger.info( - f"Selected {len(logit_idx)} logits with cumulative probability {logit_p.sum().item():.4f}" - ) + + log_attribution_target_info(targets, attribution_targets, logger) if offload: offload_handles += offload_modules([model.embed_location], offload) @@ -176,8 +194,7 @@ def _run_attribution( offload_handles += offload_modules([model.lm_head], offload) logit_offset = len(feat_layers) + (n_layers + 1) * n_pos - logit_offset = logit_offset - n_logits = len(logit_idx) + n_logits = len(targets) total_nodes = logit_offset + n_logits actual_max_feature_nodes = min(max_feature_nodes or total_active_feats, total_active_feats) @@ -193,8 +210,8 @@ def _run_attribution( logger.info("Phase 3: Computing logit attributions") phase3_start = time.time() i = -1 - for i in range(0, len(logit_idx), batch_size): - batch = logit_vecs[i : i + batch_size] + for i in range(0, len(targets), batch_size): + batch = targets.logit_vectors[i : i + batch_size] rows = ctx.compute_batch( 
layers=torch.full((batch.shape[0],), n_layers), positions=torch.full((batch.shape[0],), n_pos - 1), @@ -225,7 +242,7 @@ def _run_attribution( pending = torch.arange(total_active_feats) else: influences = compute_partial_influences( - edge_matrix[:st], logit_p, row_to_node_index[:st] + edge_matrix[:st], targets.logit_probabilities, row_to_node_index[:st] ) feature_rank = torch.argsort(influences[:total_active_feats], descending=True).cpu() queue_size = min(update_interval * batch_size, actual_max_feature_nodes - n_visited) @@ -270,8 +287,9 @@ def _run_attribution( graph = Graph( input_string=model.tokenizer.decode(input_ids), input_tokens=input_ids, - logit_tokens=logit_idx, - logit_probabilities=logit_p, + logit_targets=targets.logit_targets, + logit_probabilities=targets.logit_probabilities, + vocab_size=targets.vocab_size, active_features=activation_matrix.indices().T, activation_values=activation_matrix.values(), selected_features=selected_features, diff --git a/circuit_tracer/attribution/attribute_transformerlens.py b/circuit_tracer/attribution/attribute_transformerlens.py index 86c4063c..c82b57e1 100644 --- a/circuit_tracer/attribution/attribute_transformerlens.py +++ b/circuit_tracer/attribution/attribute_transformerlens.py @@ -22,23 +22,29 @@ import logging import time +from collections.abc import Sequence from typing import Literal import torch from tqdm import tqdm +from circuit_tracer.attribution.targets import ( + AttributionTargets, + TargetSpec, + log_attribution_target_info, +) from circuit_tracer.graph import Graph, compute_partial_influences from circuit_tracer.replacement_model.replacement_model_transformerlens import ( TransformerLensReplacementModel, ) from circuit_tracer.utils.disk_offload import offload_modules -from circuit_tracer.utils.salient_logits import compute_salient_logits def attribute( prompt: str | torch.Tensor | list[int], model: TransformerLensReplacementModel, *, + attribution_targets: Sequence[str] | Sequence[TargetSpec] | 
torch.Tensor | None = None, max_n_logits: int = 10, desired_logit_prob: float = 0.95, batch_size: int = 512, @@ -47,13 +53,21 @@ def attribute( verbose: bool = False, update_interval: int = 4, ) -> Graph: - """Compute an attribution graph for *prompt*. + """Compute an attribution graph for *prompt* using TransformerLens backend. Args: prompt: Text, token ids, or tensor - will be tokenized if str. - model: Frozen ``ReplacementModel`` - max_n_logits: Max number of logit nodes. - desired_logit_prob: Keep logits until cumulative prob >= this value. + model: Frozen ``TransformerLensReplacementModel`` + attribution_targets: Target specification in one of four formats: + - None: Auto-select salient logits based on probability threshold + - torch.Tensor: Tensor of token indices + - Sequence[str]: Token strings (tokenized, auto-computes probability + and unembed vector) + - Sequence[TargetSpec]: Fully specified custom targets (CustomTarget or tuple) + with arbitrary unembed directions + max_n_logits: Max number of logit nodes (used when attribution_targets is None). + desired_logit_prob: Keep logits until cumulative prob >= this value + (used when attribution_targets is None). batch_size: How many source nodes to process per backward pass. max_feature_nodes: Max number of feature nodes to include in the graph. offload: Method for offloading model parameters to save memory. 
@@ -83,6 +97,7 @@ def attribute( return _run_attribution( model=model, prompt=prompt, + attribution_targets=attribution_targets, max_n_logits=max_n_logits, desired_logit_prob=desired_logit_prob, batch_size=batch_size, @@ -104,6 +119,7 @@ def attribute( def _run_attribution( model, prompt, + attribution_targets, max_n_logits, desired_logit_prob, batch_size, @@ -147,21 +163,22 @@ def _run_attribution( n_layers, n_pos, _ = activation_matrix.shape total_active_feats = activation_matrix._nnz() - logit_idx, logit_p, logit_vecs = compute_salient_logits( - ctx.logits[0, -1], - model.unembed.W_U, + targets = AttributionTargets( + attribution_targets=attribution_targets, + logits=ctx.logits[0, -1], + unembed_proj=model.unembed.W_U, + tokenizer=model.tokenizer, max_n_logits=max_n_logits, desired_logit_prob=desired_logit_prob, ) - logger.info( - f"Selected {len(logit_idx)} logits with cumulative probability {logit_p.sum().item():.4f}" - ) + + log_attribution_target_info(targets, attribution_targets, logger) if offload: offload_handles += offload_modules([model.unembed, model.embed], offload) logit_offset = len(feat_layers) + (n_layers + 1) * n_pos - n_logits = len(logit_idx) + n_logits = len(targets) total_nodes = logit_offset + n_logits max_feature_nodes = min(max_feature_nodes or total_active_feats, total_active_feats) @@ -176,8 +193,8 @@ def _run_attribution( # Phase 3: logit attribution logger.info("Phase 3: Computing logit attributions") phase_start = time.time() - for i in range(0, len(logit_idx), batch_size): - batch = logit_vecs[i : i + batch_size] + for i in range(0, len(targets), batch_size): + batch = targets.logit_vectors[i : i + batch_size] rows = ctx.compute_batch( layers=torch.full((batch.shape[0],), n_layers), positions=torch.full((batch.shape[0],), n_pos - 1), @@ -203,7 +220,7 @@ def _run_attribution( pending = torch.arange(total_active_feats) else: influences = compute_partial_influences( - edge_matrix[:st], logit_p, row_to_node_index[:st] + edge_matrix[:st], 
targets.logit_probabilities, row_to_node_index[:st] ) feature_rank = torch.argsort(influences[:total_active_feats], descending=True).cpu() queue_size = min(update_interval * batch_size, max_feature_nodes - n_visited) @@ -248,8 +265,9 @@ def _run_attribution( graph = Graph( input_string=model.tokenizer.decode(input_ids), input_tokens=input_ids, - logit_tokens=logit_idx, - logit_probabilities=logit_p, + logit_targets=targets.logit_targets, + logit_probabilities=targets.logit_probabilities, + vocab_size=targets.vocab_size, active_features=activation_matrix.indices().T, activation_values=activation_matrix.values(), selected_features=selected_features, diff --git a/circuit_tracer/attribution/context_nnsight.py b/circuit_tracer/attribution/context_nnsight.py index 63f5e7d3..bf093127 100644 --- a/circuit_tracer/attribution/context_nnsight.py +++ b/circuit_tracer/attribution/context_nnsight.py @@ -103,8 +103,7 @@ def compute_score( proxy = weakref.proxy(self) proxy._batch_buffer[write_index] += einsum( - grads[read_index], - # grads.to(output_vecs.dtype)[read_index], + grads.to(output_vecs.dtype)[read_index], output_vecs, "batch position d_model, position d_model -> position batch", ) diff --git a/circuit_tracer/attribution/targets.py b/circuit_tracer/attribution/targets.py new file mode 100644 index 00000000..9a5292d9 --- /dev/null +++ b/circuit_tracer/attribution/targets.py @@ -0,0 +1,440 @@ +"""Attribution target specification and processing. + +This module provides the AttributionTargets container class and LogitTarget record +structure for specifying and processing attribution targets in the format required +for attribution graph computation. + +Key concepts: +- AttributionTargets: High-level container that encapsulates target specifications +- LogitTarget: Low-level data transfer object (DTO) storing token metadata +- Virtual indices: Technique for representing out-of-vocabulary (OOV) tokens using + synthetic indices >= vocab_size. 
Required to support arbitrary string token (or functions thereof) + attribution functionality. +""" + +from collections.abc import Sequence +from typing import NamedTuple +import logging + +import torch + + +class LogitTarget(NamedTuple): + """Token metadata for attribution: string representation and vocabulary index.""" + + token_str: str + vocab_idx: int + + +class CustomTarget(NamedTuple): + """A fully specified custom attribution target. + + Attributes: + token_str: Label for this target (e.g., "logit(x)-logit(y)") + prob: Weight/probability for this target + vec: Custom unembed direction vector (d_model,) + """ + + token_str: str + prob: float + vec: torch.Tensor + + +TargetSpec = CustomTarget | tuple[str, float, torch.Tensor] + + +class AttributionTargets: + """Container for processed attribution target specifications. + + Encapsulates target identifiers, softmax probabilities, and demeaned unembedding + vectors needed for attribution graph computation. + + Supports four input formats: + - None: Auto-select salient logits by probability threshold + - torch.Tensor: Specific vocabulary indices (token IDs) + - Sequence[str]: Token strings (tokenized internally) + - Sequence[TargetSpec]: Fully specified custom targets (CustomTarget or raw tuple[str, float, torch.Tensor]) + + Attributes: + logit_targets: List of LogitTarget records with token strings and vocab indices + logit_probabilities: Softmax probabilities for each target (k,) + logit_vectors: Demeaned unembedding vectors (k, d_model) + """ + + def __init__( + self, + attribution_targets: Sequence[str] | Sequence[TargetSpec] | torch.Tensor | None, + logits: torch.Tensor, + unembed_proj: torch.Tensor, + tokenizer, + *, + max_n_logits: int = 10, + desired_logit_prob: float = 0.95, + ): + """Build attribution targets from user specification. 
+ + Args: + attribution_targets: Target specification in one of four formats: + - None: Auto-select salient logits based on probability threshold + - torch.Tensor: Tensor of vocabulary token IDs + - Sequence[str]: Token strings (tokenized, then auto-computes probability & vector) + - Sequence[TargetSpec]: Fully specified custom targets (CustomTarget or + tuple[str, float, torch.Tensor]) with custom probability and unembed direction + (uses virtual index for OOV tokens) + logits: ``(d_vocab,)`` logit vector for single position + unembed_proj: ``(d_model, d_vocab)`` unembedding matrix + tokenizer: Tokenizer for string→int conversion + max_n_logits: Max targets when auto-selecting (salient mode) + desired_logit_prob: Probability threshold for salient mode + """ + # Store tokenizer ref for decoding vocab indices to token strings + self.tokenizer = tokenizer + ctor_shared = {"logits": logits, "unembed_proj": unembed_proj, "tokenizer": tokenizer} + + # Dispatch to appropriate constructor based on input type + if attribution_targets is None: + salient_ctor = {"max_n_logits": max_n_logits, "desired_logit_prob": desired_logit_prob} + attr_spec = self._from_salient(**salient_ctor, **ctor_shared) + elif isinstance(attribution_targets, torch.Tensor): + attr_spec = self._from_indices(indices=attribution_targets, **ctor_shared) + elif isinstance(attribution_targets, Sequence): + if not attribution_targets: + raise ValueError("attribution_targets sequence cannot be empty") + first = attribution_targets[0] + if isinstance(first, str): + attr_spec = self._from_str(token_strs=attribution_targets, **ctor_shared) # type: ignore[arg-type] + elif isinstance(first, (tuple, CustomTarget)): + attr_spec = self._from_tuple(target_tuples=attribution_targets, **ctor_shared) # type: ignore[arg-type] + else: + raise TypeError( + f"Sequence elements must be str or TargetSpec (CustomTarget or " + f"tuple[str, float, Tensor]), got {type(first)}" + ) + else: + raise TypeError( + 
f"attribution_targets must be None, torch.Tensor, Sequence[str], " + f"or Sequence[TargetSpec], got {type(attribution_targets)}" + ) + self.logit_targets, self.logit_probabilities, self.logit_vectors = attr_spec + + def __len__(self) -> int: + """Number of attribution targets.""" + return len(self.logit_targets) + + def __repr__(self) -> str: + """String representation showing key info.""" + if len(self.logit_targets) > 3: + targets_preview = self.logit_targets[:3] + suffix = "..." + else: + targets_preview = self.logit_targets + suffix = "" + return f"AttributionTargets(n={len(self)}, targets={targets_preview}{suffix})" + + @property + def tokens(self) -> list[str]: + """Get token strings for all targets. + + Returns: + List of token strings (decoded vocab tokens or arbitrary strings) + """ + return [target.token_str for target in self.logit_targets] + + @property + def vocab_size(self) -> int: + """Vocabulary size from the tokenizer. + + Returns: + Vocabulary size for determining virtual vs real indices + """ + return self.tokenizer.vocab_size + + @property + def token_ids(self) -> torch.Tensor: + """Tensor of vocabulary indices (including virtual indices >= vocab_size). + + Returns a torch.Tensor of vocab indices on the same device as other tensors, + suitable for indexing into logit vectors or embeddings. + + Returns: + torch.Tensor: Long tensor of vocabulary indices + """ + return torch.tensor( + [target.vocab_idx for target in self.logit_targets], + dtype=torch.long, + device=self.logit_probabilities.device, + ) + + def to(self, device: str | torch.device) -> "AttributionTargets": + """Transfer AttributionTargets to specified device. + + Only moves torch.Tensor attributes (logit_probabilities, logit_vectors); + logit_targets list stays unchanged. 
+ + Args: + device: Target device (e.g., "cuda", "cpu") + + Returns: + Self with tensors on new device + """ + self.logit_probabilities = self.logit_probabilities.to(device) + self.logit_vectors = self.logit_vectors.to(device) + return self + + @staticmethod + def _from_salient( + logits: torch.Tensor, + unembed_proj: torch.Tensor, + max_n_logits: int, + desired_logit_prob: float, + tokenizer, + ) -> tuple[list[LogitTarget], torch.Tensor, torch.Tensor]: + """Auto-select salient logits by cumulative probability. + + Picks the smallest set of logits whose cumulative probability + exceeds the threshold, up to max_n_logits. + + Args: + logits: ``(d_vocab,)`` logit vector + unembed_proj: ``(d_model, d_vocab)`` unembedding matrix + max_n_logits: Hard cap on number of logits + desired_logit_prob: Cumulative probability threshold + tokenizer: Tokenizer for decoding vocab indices to strings + + Returns: + Tuple of (logit_targets, probabilities, vectors) where logit_targets + contains LogitTarget instances with actual vocab indices + """ + probs = torch.softmax(logits, dim=-1) + top_p, top_idx = torch.topk(probs, max_n_logits) + cutoff = int(torch.searchsorted(torch.cumsum(top_p, 0), desired_logit_prob)) + 1 + indices, probs, vecs = AttributionTargets._compute_logit_vecs( + top_idx[:cutoff], logits, unembed_proj + ) + logit_targets = [ + LogitTarget(token_str=tokenizer.decode(idx), vocab_idx=idx) for idx in indices.tolist() + ] + return logit_targets, probs, vecs + + @staticmethod + def _from_indices( + indices: torch.Tensor, + logits: torch.Tensor, + unembed_proj: torch.Tensor, + tokenizer, + ) -> tuple[list[LogitTarget], torch.Tensor, torch.Tensor]: + """Construct from specific vocabulary indices. 
+ + Args: + indices: ``(k,)`` tensor of vocabulary indices + logits: ``(d_vocab,)`` logit vector + unembed_proj: ``(d_model, d_vocab)`` unembedding matrix + tokenizer: Tokenizer for decoding vocab indices to strings + + Returns: + Tuple of (logit_targets, probabilities, vectors) where logit_targets + contains LogitTarget instances with actual vocab indices + + Raises: + ValueError: If any index is out of vocabulary range + """ + vocab_size = logits.shape[0] + + # Validate all indices are within vocab range + if (indices < 0).any() or (indices >= vocab_size).any(): + invalid = indices[(indices < 0) | (indices >= vocab_size)] + raise ValueError( + f"Token indices must be in range [0, {vocab_size}), " + f"but found invalid indices: {invalid.tolist()}" + ) + + indices_out, probs, vecs = AttributionTargets._compute_logit_vecs( + indices, logits, unembed_proj + ) + + # Create LogitTarget instances with decoded token strings + logit_targets = [ + LogitTarget(token_str=tokenizer.decode(idx), vocab_idx=idx) + for idx in indices_out.tolist() + ] + return logit_targets, probs, vecs + + @staticmethod + def _from_str( + token_strs: Sequence[str], + logits: torch.Tensor, + unembed_proj: torch.Tensor, + tokenizer, + ) -> tuple[list[LogitTarget], torch.Tensor, torch.Tensor]: + """Construct from a sequence of token strings. + + Each string is tokenized and its probability/vector auto-computed. 
+ + Args: + token_strs: Sequence of token strings + logits: ``(d_vocab,)`` logit vector + unembed_proj: Unembedding matrix + tokenizer: Tokenizer for string→int conversion + + Returns: + Tuple of (logit_targets, probabilities, vectors) + """ + vocab_size = logits.shape[0] + indices = [] + for token_str in token_strs: + try: + ids = tokenizer.encode(token_str, add_special_tokens=False) + except Exception as e: + raise ValueError( + f"Failed to encode string token {token_str!r} using tokenizer: {e}" + ) from e + if not ids: + raise ValueError(f"String token {token_str!r} encoded to empty token sequence.") + if len(ids) > 1: + raise ValueError( + f"String token {token_str!r} encoded to {len(ids)} tokens " + f"(IDs: {ids}). Each string must map to exactly one token. " + f"Consider providing single-token strings." + ) + token_id = ids[0] + assert 0 <= token_id < vocab_size, ( + f"Token {token_str!r} resolved to index {token_id}, " + f"out of vocabulary range [0, {vocab_size})" + ) + indices.append(token_id) + return AttributionTargets._from_indices( + indices=torch.tensor(indices, dtype=torch.long), + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + ) + + @staticmethod + def _validate_custom_target( + target: TargetSpec, + ) -> CustomTarget: + """Validate and normalize a custom target. 
+ + Args: + target: A CustomTarget or raw (token_str, prob, vec) tuple + + Returns: + Validated CustomTarget instance + + Raises: + ValueError: If the tuple has wrong length or element types + """ + if not isinstance(target, CustomTarget): + if len(target) != 3: + raise ValueError( + f"Tuple targets must have exactly 3 elements " + f"(token_str, probability, vector), got {len(target)}" + ) + token_str, prob, vec = target + else: + token_str, prob, vec = target.token_str, target.prob, target.vec + if not isinstance(token_str, str): + raise TypeError(f"Custom target token_str must be str, got {type(token_str)}") + if not isinstance(prob, (int, float)): + raise TypeError(f"Custom target prob must be int or float, got {type(prob)}") + if not isinstance(vec, torch.Tensor): + raise TypeError(f"Custom target vec must be torch.Tensor, got {type(vec)}") + return CustomTarget(token_str=token_str, prob=float(prob), vec=vec) + + @staticmethod + def _from_tuple( + target_tuples: Sequence[TargetSpec], + logits: torch.Tensor, + unembed_proj: torch.Tensor, + tokenizer, + ) -> tuple[list[LogitTarget], torch.Tensor, torch.Tensor]: + """Construct from fully specified custom targets. + + Each target provides (token_str, prob, vec) for an arbitrary + attribution direction that may not correspond to a vocabulary token. 
+ + Args: + target_tuples: Sequence of CustomTarget or raw tuple instances + logits: ``(d_vocab,)`` logit vector (used for vocab_size) + unembed_proj: Unembedding matrix (unused but kept for interface consistency) + tokenizer: Tokenizer (unused but kept for interface consistency) + + Returns: + Tuple of (logit_targets, probabilities, vectors) + """ + vocab_size = logits.shape[0] + logit_targets, probs, vecs = [], [], [] + for position, target in enumerate(target_tuples): + validated = AttributionTargets._validate_custom_target(target) + virtual_idx = vocab_size + position + logit_targets.append(LogitTarget(token_str=validated.token_str, vocab_idx=virtual_idx)) + probs.append(validated.prob) + vecs.append(validated.vec) + return logit_targets, torch.tensor(probs), torch.stack(vecs) + + @staticmethod + def _compute_logit_vecs( + indices: torch.Tensor, + logits: torch.Tensor, + unembed_proj: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute probabilities and demeaned vectors for indices. 
+ + Args: + indices: ``(k,)`` vocabulary indices to compute vectors for + logits: ``(d_vocab,)`` logit vector for single position + unembed_proj: ``(d_model, d_vocab)`` or ``(d_vocab, d_model)`` unembedding matrix + (orientation auto-detected by matching vocab dimension to logits) + + Returns: + Tuple of: + * indices - ``(k,)`` vocabulary ids (same as input) + * probabilities - ``(k,)`` softmax probabilities + * demeaned_vecs - ``(k, d_model)`` unembedding columns, demeaned + """ + probs = torch.softmax(logits, dim=-1) + selected_probs = probs[indices] + + # Auto-detect matrix orientation by matching against vocabulary size + d_vocab = logits.shape[0] + if unembed_proj.shape[0] == d_vocab: + # Shape is (d_vocab, d_model) – first axis is vocabulary (e.g., NNSight) + cols = unembed_proj[indices] # (k, d_model) + demean = unembed_proj.mean(dim=0, keepdim=True) # (1, d_model) + demeaned_vecs = cols - demean # (k, d_model) + else: + # Shape is (d_model, d_vocab) – second axis is vocabulary (e.g., TransformerLens) + cols = unembed_proj[:, indices] # (d_model, k) + demean = unembed_proj.mean(dim=-1, keepdim=True) # (d_model, 1) + demeaned_vecs = (cols - demean).T # (k, d_model) + + return indices, selected_probs, demeaned_vecs + + +def log_attribution_target_info( + targets: "AttributionTargets", + attribution_targets: Sequence[str] | Sequence[TargetSpec] | torch.Tensor | None, + logger: logging.Logger, +) -> None: + """Log information about attribution targets. 
+ + Args: + targets: AttributionTargets instance with processed targets + attribution_targets: Original attribution_targets specification + logger: Logger to use for output + """ + prob_sum = targets.logit_probabilities.sum().item() + if attribution_targets is None: + target_desc = "salient logits" + weight_desc = "cumulative probability" + elif ( + isinstance(attribution_targets, Sequence) + and attribution_targets + and isinstance(attribution_targets[0], (tuple, CustomTarget)) + ): + target_desc = "custom attribution targets" + weight_desc = "total weight" + else: + target_desc = "specified logit targets" + weight_desc = "cumulative probability" + logger.info(f"Using {len(targets)} {target_desc} with {weight_desc} {prob_sum:.4f}") diff --git a/circuit_tracer/graph.py b/circuit_tracer/graph.py index 9dc11945..ba98e98b 100644 --- a/circuit_tracer/graph.py +++ b/circuit_tracer/graph.py @@ -1,4 +1,9 @@ +"""Graph data structures for attribution results.""" + +from __future__ import annotations + from typing import NamedTuple +import warnings import torch @@ -7,19 +12,22 @@ UnifiedConfig, ) from circuit_tracer.utils import get_default_device +from circuit_tracer.attribution.targets import LogitTarget class Graph: input_string: str input_tokens: torch.Tensor - logit_tokens: torch.Tensor + logit_targets: list[LogitTarget] active_features: torch.Tensor adjacency_matrix: torch.Tensor selected_features: torch.Tensor activation_values: torch.Tensor logit_probabilities: torch.Tensor + vocab_size: int cfg: UnifiedConfig scan: str | list[str] | None + n_pos: int def __init__( self, @@ -28,11 +36,12 @@ def __init__( active_features: torch.Tensor, adjacency_matrix: torch.Tensor, cfg, - logit_tokens: torch.Tensor, - logit_probabilities: torch.Tensor, selected_features: torch.Tensor, activation_values: torch.Tensor, + logit_targets: list[LogitTarget], + logit_probabilities: torch.Tensor, scan: str | list[str] | None = None, + vocab_size: int | None = None, ): """ A graph object 
containing the adjacency matrix describing the direct effect of each @@ -44,30 +53,35 @@ def __init__( Args: input_string (str): The input string attributed. - input_tokens (List[str]): The input tokens attributed. + input_tokens (torch.Tensor): The input tokens attributed. active_features (torch.Tensor): A tensor of shape (n_active_features, 3) containing the indices (layer, pos, feature_idx) of the non-zero features of the model on the given input string. adjacency_matrix (torch.Tensor): The adjacency matrix. Organized as [active_features, error_nodes, embed_nodes, logit_nodes], where there are model.cfg.n_layers * len(input_tokens) error nodes, len(input_tokens) embed - nodes, len(logit_tokens) logit nodes. The rows represent target nodes, while + nodes, len(logit_targets) logit nodes. The rows represent target nodes, while columns represent source nodes. - cfg (HookedTransformerConfig): The cfg of the model. - logit_tokens (List[str]): The logit tokens attributed from. - logit_probabilities (torch.Tensor): The probabilities of each logit token, given - the input string. + cfg: The cfg of the model. + selected_features (torch.Tensor): Indices into active_features for selected nodes. + activation_values (torch.Tensor): Activation values for selected features. + logit_targets: List of LogitTarget records describing each logit target. + logit_probabilities: Tensor of logit target probabilities/weights. scan (Union[str,List[str]] | None, optional): The identifier of the transcoders used in the graph. Without a scan, the graph cannot be uploaded (since we won't know what transcoders were used). Defaults to None + vocab_size: Vocabulary size. If not provided, defaults to cfg.d_vocab. 
""" + self.logit_targets = logit_targets + self.logit_probabilities = logit_probabilities + self.vocab_size = vocab_size if vocab_size is not None else cfg.d_vocab + self.input_string = input_string self.adjacency_matrix = adjacency_matrix + # Convert cfg to UnifiedConfig (handles both HookedTransformerConfig and NNSight configs) self.cfg = convert_nnsight_config_to_transformerlens(cfg) self.n_pos = len(input_tokens) self.active_features = active_features - self.logit_tokens = logit_tokens - self.logit_probabilities = logit_probabilities self.input_tokens = input_tokens if scan is None: print("Graph loaded without scan to identify it. Uploading will not be possible.") @@ -83,9 +97,41 @@ def to(self, device): """ self.adjacency_matrix = self.adjacency_matrix.to(device) self.active_features = self.active_features.to(device) - self.logit_tokens = self.logit_tokens.to(device) + # logit_targets is list[LogitTarget], no device transfer needed self.logit_probabilities = self.logit_probabilities.to(device) + @property + def logit_token_ids(self) -> torch.Tensor: + """Tensor of logit target token IDs. + + Returns token IDs for logit targets on the same device as other graph tensors. + + Returns: + torch.Tensor: Long tensor of vocabulary indices + """ + return torch.tensor( + [target.vocab_idx for target in self.logit_targets], + dtype=torch.long, + device=self.logit_probabilities.device, + ) + + @property + def logit_tokens(self) -> torch.Tensor: + """Get logit target token IDs tensor (legacy compatibility). + + .. deprecated:: + Use `logit_token_ids` property instead. This is an alias for backward compatibility. + + Raises: + ValueError: If any targets have virtual indices + """ + warnings.warn( + "logit_tokens property is deprecated. 
Use logit_token_ids property instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.logit_token_ids + def to_pt(self, path: str): """Saves the graph at the given path @@ -97,8 +143,9 @@ def to_pt(self, path: str): "adjacency_matrix": self.adjacency_matrix, "cfg": self.cfg, "active_features": self.active_features, - "logit_tokens": self.logit_tokens, + "logit_targets": self.logit_targets, "logit_probabilities": self.logit_probabilities, + "vocab_size": self.vocab_size, "input_tokens": self.input_tokens, "selected_features": self.selected_features, "activation_values": self.activation_values, @@ -110,6 +157,9 @@ def to_pt(self, path: str): def from_pt(path: str, map_location="cpu") -> "Graph": """Load a graph (saved using graph.to_pt) from a .pt file at the given path. + Handles backward compatibility with older serialized graphs that stored + logit_targets as a torch.Tensor of token IDs. + Args: path (str): The path of the Graph to load map_location (str, optional): the device to load the graph onto. @@ -119,6 +169,12 @@ def from_pt(path: str, map_location="cpu") -> "Graph": Graph: the Graph saved at the specified path """ d = torch.load(path, weights_only=False, map_location=map_location) + # BC: convert legacy tensor logit_targets to LogitTarget list + lt = d.get("logit_targets") + if isinstance(lt, torch.Tensor): + d["logit_targets"] = [ + LogitTarget(token_str="", vocab_idx=int(idx)) for idx in lt.tolist() + ] return Graph(**d) @@ -199,7 +255,7 @@ def prune_graph( # Extract dimensions n_tokens = len(graph.input_tokens) - n_logits = len(graph.logit_tokens) + n_logits = len(graph.logit_targets) n_features = len(graph.selected_features) logit_weights = torch.zeros( @@ -276,11 +332,11 @@ def compute_graph_scores(graph: Graph) -> tuple[float, float]: reconstruction where all computation flows through interpretable features. Lower scores indicate more reliance on error nodes, suggesting incomplete feature coverage. 
""" - n_logits = len(graph.logit_tokens) + n_logits = len(graph.logit_targets) n_tokens = len(graph.input_tokens) n_features = len(graph.selected_features) error_start = n_features - error_end = error_start + n_tokens * graph.cfg.n_layers # type: ignore + error_end = error_start + n_tokens * graph.cfg.n_layers token_end = error_end + n_tokens logit_weights = torch.zeros( @@ -309,7 +365,24 @@ def compute_partial_influences( max_iter: int = 128, device=None, ): - """Compute partial influences using power iteration method.""" + """Compute partial influences using power iteration method. + + This function calculates the influence of each node on the output logits + based on the edge weights in the graph. + + Args: + edge_matrix: The edge weight matrix. + logit_p: The logit probabilities. + row_to_node_index: Mapping from row indices to node indices. + max_iter: Maximum number of iterations for convergence. + device: Device to perform computation on. + + Returns: + torch.Tensor: Influence values for each node. + + Raises: + RuntimeError: If computation fails to converge within max_iter. 
+ """ device = device or get_default_device() normalized_matrix = torch.empty_like(edge_matrix, device=device).copy_(edge_matrix) diff --git a/circuit_tracer/replacement_model/replacement_model_nnsight.py b/circuit_tracer/replacement_model/replacement_model_nnsight.py index 964c856e..8f46aeaf 100644 --- a/circuit_tracer/replacement_model/replacement_model_nnsight.py +++ b/circuit_tracer/replacement_model/replacement_model_nnsight.py @@ -3,7 +3,7 @@ from collections.abc import Sequence from contextlib import contextmanager from functools import partial -from typing import Callable, Iterator, Literal +from typing import Callable, Iterator, Literal, cast import torch from torch import nn @@ -241,8 +241,12 @@ def _configure_replacement_model( self._embed_location = nnsight_config.embed_location # these are real weights, not envoys - self.embed_weight = self._resolve_attr(self, nnsight_config.embed_weight) - self.unembed_weight = self._resolve_attr(self, nnsight_config.unembed_weight) + self.embed_weight = cast( + torch.Tensor, self._resolve_attr(self, nnsight_config.embed_weight) + ) + self.unembed_weight = cast( + torch.Tensor, self._resolve_attr(self, nnsight_config.unembed_weight) + ) self.scan = transcoder_set.scan # Make sure the replacement model is entirely frozen by default. 
diff --git a/circuit_tracer/utils/__init__.py b/circuit_tracer/utils/__init__.py index 078ee3fb..e23c6fd5 100644 --- a/circuit_tracer/utils/__init__.py +++ b/circuit_tracer/utils/__init__.py @@ -7,6 +7,4 @@ def get_default_device() -> torch.device: return torch.device("cuda" if torch.cuda.is_available() else "cpu") -__all__ = [ - "create_graph_files", -] +__all__ = ["create_graph_files", "get_default_device"] diff --git a/circuit_tracer/utils/create_graph_files.py b/circuit_tracer/utils/create_graph_files.py index 9d8f1c26..6691464e 100644 --- a/circuit_tracer/utils/create_graph_files.py +++ b/circuit_tracer/utils/create_graph_files.py @@ -1,17 +1,25 @@ +from __future__ import annotations + import logging import os import time + import torch from transformers import AutoTokenizer from circuit_tracer.frontend.graph_models import Metadata, Model, Node, QParams from circuit_tracer.frontend.utils import add_graph_metadata +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from circuit_tracer.graph import Graph + logger = logging.getLogger(__name__) -def load_graph_data(file_path): +def load_graph_data(file_path) -> Graph: """Load graph data from a PyTorch file.""" from circuit_tracer.graph import Graph @@ -22,7 +30,7 @@ def load_graph_data(file_path): return graph -def create_nodes(graph, node_mask, tokenizer, cumulative_scores): +def create_nodes(graph: Graph, node_mask, tokenizer, cumulative_scores): """Create all nodes for the graph.""" start_time = time.time() @@ -30,7 +38,7 @@ def create_nodes(graph, node_mask, tokenizer, cumulative_scores): n_features = len(graph.selected_features) layers = graph.cfg.n_layers - error_end_idx = n_features + graph.n_pos * layers # type: ignore + error_end_idx = n_features + graph.n_pos * layers token_end_idx = error_end_idx + len(graph.input_tokens) for node_idx in node_mask.nonzero().squeeze().tolist(): @@ -53,10 +61,16 @@ def create_nodes(graph, node_mask, tokenizer, cumulative_scores): ) elif node_idx in 
range(token_end_idx, len(cumulative_scores)): pos = node_idx - token_end_idx + + # vocab_idx can be either a valid token_id (< vocab_size) or a virtual + # index (>= vocab_size) for arbitrary strings/functions thereof. The virtual indices + # encode the position in the logit_targets list as: vocab_size + position. + token, vocab_idx = graph.logit_targets[pos] + nodes[node_idx] = Node.logit_node( pos=graph.n_pos - 1, - vocab_idx=graph.logit_tokens[pos], - token=tokenizer.decode(graph.logit_tokens[pos]), + vocab_idx=vocab_idx, + token=token, target_logit=pos == 0, token_prob=graph.logit_probabilities[pos].item(), num_layers=layers, @@ -68,7 +82,7 @@ def create_nodes(graph, node_mask, tokenizer, cumulative_scores): return nodes -def create_used_nodes_and_edges(graph, nodes, edge_mask): +def create_used_nodes_and_edges(graph: Graph, nodes, edge_mask): """Filter to only used nodes and create edges.""" start_time = time.time() edges = edge_mask.numpy() @@ -102,7 +116,7 @@ def create_used_nodes_and_edges(graph, nodes, edge_mask): return used_nodes, used_edges -def build_model(graph, used_nodes, used_edges, slug, scan, node_threshold, tokenizer): +def build_model(graph: Graph, used_nodes, used_edges, slug, scan, node_threshold, tokenizer): """Build the full model object.""" start_time = time.time() @@ -145,13 +159,14 @@ def build_model(graph, used_nodes, used_edges, slug, scan, node_threshold, token def create_graph_files( - graph_or_path, + graph_or_path: Graph | str, slug: str, output_path, scan=None, node_threshold=0.8, edge_threshold=0.98, ): + # Import Graph/prune_graph locally to avoid circular import at module import time from circuit_tracer.graph import Graph, prune_graph total_start_time = time.time() diff --git a/circuit_tracer/utils/demo_utils.py b/circuit_tracer/utils/demo_utils.py index 9e9c2c97..256ecc7b 100644 --- a/circuit_tracer/utils/demo_utils.py +++ b/circuit_tracer/utils/demo_utils.py @@ -1,3 +1,4 @@ +import gc import html import json import 
urllib.parse @@ -6,9 +7,379 @@ import torch from IPython.display import HTML, display +from circuit_tracer.attribution.targets import CustomTarget +from circuit_tracer.graph import compute_node_influence + Feature = namedtuple("Feature", ["layer", "pos", "feature_idx"]) +def get_unembed_vecs(model, token_ids: list[int], backend: str) -> list[torch.Tensor]: + """Extract unembedding column vectors for the given token IDs. + + Handles the orientation difference between TransformerLens (d_model, d_vocab) + and NNSight (d_vocab, d_model) unembedding matrices. + + Args: + model: A ``ReplacementModel`` instance. + token_ids: Vocabulary indices whose unembed columns to extract. + backend: ``"transformerlens"`` or ``"nnsight"``. + + Returns: + List of 1-D tensors, one per token ID, each of shape ``(d_model,)``. + """ + unembed = model.unembed.W_U if backend == "transformerlens" else model.unembed_weight + d_vocab = model.tokenizer.vocab_size + if unembed.shape[0] == d_vocab: + return [unembed[tid] for tid in token_ids] + return [unembed[:, tid] for tid in token_ids] + + +def cleanup_cuda() -> None: + """Run garbage collection and free CUDA cache.""" + gc.collect() + torch.cuda.empty_cache() + + +def get_top_features(graph, n: int = 10) -> tuple[list[tuple[int, int, int]], list[float]]: + """Extract the top-N feature nodes from the graph by total multi-hop influence. + + Uses ``compute_node_influence`` to rank features by their total effect + on *all* logit targets (direct + indirect paths), weighted by each + target's probability. + + Args: + graph: A Graph object with ``adjacency_matrix``, ``selected_features``, + ``active_features``, ``logit_targets``, and ``logit_probabilities``. + n: Number of top features to return. + + Returns: + Tuple of (features, scores) where *features* is a list of + ``(layer, pos, feature_idx)`` tuples and *scores* is the + corresponding influence values. 
+ """ + n_logits = len(graph.logit_targets) + n_features = len(graph.selected_features) + + # Build logit weight vector + logit_weights = torch.zeros( + graph.adjacency_matrix.shape[0], device=graph.adjacency_matrix.device + ) + logit_weights[-n_logits:] = graph.logit_probabilities + + # Multi-hop influence across all logit targets + node_influence = compute_node_influence(graph.adjacency_matrix, logit_weights) + feature_influence = node_influence[:n_features] + + top_k = min(n, n_features) + top_values, top_indices = torch.topk(feature_influence, top_k) + + features = [ + tuple(graph.active_features[graph.selected_features[i]].tolist()) for i in top_indices + ] + scores = top_values.tolist() + return features, scores + + +def display_top_features_comparison( + feature_sets: dict[str, list[tuple[int, int, int]]], + scores_sets: dict[str, list[float]] | None = None, + neuronpedia_model: str | None = None, + neuronpedia_set: str = "gemmascope-transcoder-16k", +): + """Display top features from multiple attribution configurations side by side. + + Args: + feature_sets: Mapping from config label to list of ``(layer, pos, feat_idx)`` tuples. + scores_sets: Optional mapping from config label to list of attribution scores. + If ``None``, scores are omitted from the display. + neuronpedia_model: Neuronpedia model slug (e.g. ``"gemma-2-2b"``). + When provided, feature indices become clickable links. + neuronpedia_set: Neuronpedia set name (default ``"gemmascope-transcoder-16k"``). + """ + labels = list(feature_sets.keys()) + colors = ["#2471A3", "#27AE60", "#8E44AD", "#E67E22", "#C0392B", "#16A085"] + + style = """ + + """ + + body = '
' + for i, label in enumerate(labels): + color = colors[i % len(colors)] + features = feature_sets[label] + scores = scores_sets.get(label) if scores_sets else None + body += '
' + body += ( + f'
{html.escape(label)}
' + ) + body += "" + if scores is not None: + body += "" + body += "" + for j, (layer, pos, feat_idx) in enumerate(features): + score_cell = f"" if scores is not None else "" + if neuronpedia_model is not None: + np_url = ( + f"https://www.neuronpedia.org/{neuronpedia_model}/" + f"{layer}-{neuronpedia_set}/{feat_idx}" + ) + feat_link = f'{feat_idx}' + else: + feat_link = str(feat_idx) + node_cell = f'' + body += f"{node_cell}{score_cell}" + body += "
#NodeScore
{scores[j]:.4f}({layer}, {pos}, {feat_link})
{j + 1}
" + body += "
" + + display(HTML(style + body)) + + +def display_attribution_config( + token_pairs: list[tuple[str, int]], + target_pairs: list[tuple[str, CustomTarget]], +) -> None: + """Display token-mapping and custom-target summary tables. + + Args: + token_pairs: List of ``(token_str, vocab_id)`` pairs for the Token Mappings table. + target_pairs: List of ``(kind_label, target)`` pairs for the Attribution Targets + table, where each ``target`` is a CustomTarget with ``.token_str`` and ``.prob`` attributes. + """ + th_l = "padding:5px 14px 5px 6px; border-bottom:2px solid #888; text-align:left; white-space:nowrap" + th_r = "padding:5px 14px 5px 6px; border-bottom:2px solid #888; text-align:right; white-space:nowrap" + td_l = "padding:4px 14px 4px 6px; border-bottom:1px solid #ddd; text-align:left" + td_r = "padding:4px 14px 4px 6px; border-bottom:1px solid #ddd; text-align:right" + + # ── Token Mappings ──────────────────────────────────────────────────────── + token_rows = "".join( + "" + "" + html.escape(tok) + "" + "" + str(vid) + "" + "" + for tok, vid in token_pairs + ) + display( + HTML( + "Token Mappings" + "" + "" + "" + "" + "" + "" + token_rows + "" + "
TokenVocab ID
" + ) + ) + + # ── Attribution Targets ─────────────────────────────────────────────────── + target_rows = "".join( + "" + "" + html.escape(kind) + "" + "" + html.escape(tgt.token_str) + "" + "" + f"{tgt.prob * 100:.3f}%" + "" + "" + for kind, tgt in target_pairs + ) + display( + HTML( + "Attribution Targets" + "" + "" + "" + "" + "" + "" + "" + target_rows + "" + "
TargetLabelProbability
" + ) + ) + + +def display_token_probs( + logits: torch.Tensor, + token_ids: list[int], + labels: list[str], + title: str = "", +) -> None: + """Display softmax probabilities for specific tokens as a styled HTML table. + + Probabilities are shown as percentages (3 decimal places) when ≥ 0.001, + otherwise in scientific notation (2 significant figures). + + Args: + logits: Raw logits tensor (at least 2-D; last position is used). + token_ids: Vocabulary indices to display. + labels: Human-readable label for each token. + title: Optional heading rendered above the table. + """ + probs = torch.softmax(logits.squeeze(0)[-1].float(), dim=-1) + + def _fmt(p: float) -> str: + return f"{p * 100:.3f}%" if p >= 1e-3 else f"{p:.2e}" + + rows = "" + for i, (tid, label) in enumerate(zip(token_ids, labels)): + p = probs[tid].item() + logit_val = logits.squeeze(0)[-1, tid].item() + row_class = "even-row" if i % 2 == 0 else "odd-row" + rows += ( + f'' + f'{html.escape(label)}' + f'{_fmt(p)}' + f'{logit_val:.4f}' + f"\n" + ) + + title_html = ( + f'
{html.escape(title)}
' + if title + else "" + ) + + markup = f""" +
+ {title_html} + + + + + + + + + + {rows} + +
TokenProbabilityLogit
+
+ """ + display(HTML(markup)) + + +def display_ablation_chart( + groups: dict[str, dict[str, float]], + logit_diffs: dict[str, float] | None = None, + title: str = "", + colors: list[str] | None = None, +) -> None: + """Display ablation results as a grouped bar chart with logit-difference line. + + Args: + groups: Mapping from group label (e.g. ``"Baseline"``) to a dict + of ``{token_label: probability}``. + logit_diffs: Optional mapping from group label to logit difference, + plotted as a dashed line on a secondary y-axis. + title: Chart title. + colors: Bar colours for each token. Defaults to a built-in palette. + """ + import matplotlib.pyplot as plt + import numpy as np + + group_labels = list(groups.keys()) + token_labels = list(next(iter(groups.values())).keys()) + n_groups = len(group_labels) + n_tokens = len(token_labels) + + if colors is None: + colors = ["#2471A3", "#E67E22", "#27AE60", "#C0392B", "#8E44AD"][:n_tokens] + + x = np.arange(n_groups) + width = 0.8 / n_tokens + + fig, ax1 = plt.subplots(figsize=(8, 5.0)) + + for i, tok in enumerate(token_labels): + vals = [groups[g].get(tok, 0) for g in group_labels] + offset = (i - (n_tokens - 1) / 2) * width + bars = ax1.bar( + x + offset, + vals, + width * 0.9, + label=tok, + color=colors[i], + alpha=0.85, + ) + for bar, v in zip(bars, vals): + ax1.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() + 0.005, + f"{v:.3f}", + ha="center", + va="bottom", + fontsize=8, + ) + + ax1.set_ylabel("Probability") + ax1.set_xticks(x) + ax1.set_xticklabels(group_labels) + max_prob = max(max(groups[g].get(t, 0) for t in token_labels) for g in group_labels) + ax1.set_ylim(0, max_prob * 1.4) + + if logit_diffs is not None: + ax2 = ax1.twinx() + diff_vals = [logit_diffs.get(g, 0) for g in group_labels] + ax2.plot( + x, + diff_vals, + "k--o", + label="Logit diff", + linewidth=1.5, + markersize=5, + ) + ax2.set_ylabel("Logit difference") + ax2.legend(loc="upper right") + + ax1.legend(loc="upper left") + if title: + 
ax1.set_title(title, fontsize=13, fontweight="bold") + fig.tight_layout() + plt.show() + + def get_topk(logits: torch.Tensor, tokenizer, k: int = 5): probs = torch.softmax(logits.squeeze()[-1], dim=-1) topk = torch.topk(probs, k) @@ -16,10 +387,29 @@ def get_topk(logits: torch.Tensor, tokenizer, k: int = 5): # Now let's create a version that's more adaptive to dark/light mode -def display_topk_token_predictions(sentence, original_logits, new_logits, tokenizer, k: int = 5): - """ - Version that tries to be more adaptive to both dark and light modes - using higher contrast elements and CSS variables where possible +def display_topk_token_predictions( + sentence, + original_logits, + new_logits, + tokenizer, + k: int = 5, + key_tokens: list[tuple[str, int]] | None = None, +): + """Display top-k token predictions before and after an intervention. + + Adaptive to both dark and light modes using higher-contrast elements + and CSS variables where possible. + + Args: + sentence: The input prompt string. + original_logits: Logits before the intervention. + new_logits: Logits after the intervention. + tokenizer: Tokenizer for decoding token IDs. + k: Number of top tokens to show per section. + key_tokens: Optional list of ``(token_label, token_id)`` pairs. + When provided, a third *Key Tokens* table is rendered showing + the probabilities of these specific tokens in both the original + and new distributions. """ original_tokens = get_topk(original_logits, tokenizer, k) @@ -186,6 +576,54 @@ def display_topk_token_predictions(sentence, original_logits, new_logits, tokeni + """ + + # Optional key-tokens section + if key_tokens: + orig_probs = torch.softmax(original_logits.squeeze()[-1], dim=-1) + new_probs = torch.softmax(new_logits.squeeze()[-1], dim=-1) + + html += """ +
+
Key Tokens
+ + + + + + + + + + + """ + for i, (label, tid) in enumerate(key_tokens): + p_orig = orig_probs[tid].item() + p_new = new_probs[tid].item() + relative = (p_new - p_orig) / max(p_orig, 1e-9) + sign = "+" if relative >= 0 else "" + bar_width = int(p_new / max(max_prob, 1e-9) * 100) + row_class = "even-row" if i % 2 == 0 else "odd-row" + html += f""" + + + + + + + """ + html += """ + +
TokenOriginalNewChange
{label}{p_orig:.4f}{p_new:.4f} +
+
+ {sign}{relative * 100:.1f}% +
+
+
+ """ + + html += """ """ diff --git a/circuit_tracer/utils/tl_nnsight_mapping.py b/circuit_tracer/utils/tl_nnsight_mapping.py index cae65397..e33fb5b7 100644 --- a/circuit_tracer/utils/tl_nnsight_mapping.py +++ b/circuit_tracer/utils/tl_nnsight_mapping.py @@ -241,6 +241,10 @@ def convert_nnsight_config_to_transformerlens(config): config: NNsight config object return_unified: If True, return UnifiedConfig instead of HookedTransformerConfig """ + # If already a UnifiedConfig, return as-is + if isinstance(config, UnifiedConfig): + return config + field_mappings = { # Basic model dimensions "num_hidden_layers": "n_layers", diff --git a/demos/attribute_demo.ipynb b/demos/attribute_demo.ipynb index b6739da8..cfc48a10 100644 --- a/demos/attribute_demo.ipynb +++ b/demos/attribute_demo.ipynb @@ -17,7 +17,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -44,7 +44,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": { "id": "P8fNhpqzmS8k" }, @@ -72,58 +72,11 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": { "id": "BBsETpl0mS8l" }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ad0d25d11e5f4aacae3fd4fbba264d1e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Fetching 26 files: 0%| | 0/26 [00:00 **Coming up:** After these basic modes, we explore two **custom attribution target** examples that let you attribute back from arbitrary residual-stream directions — a logit *difference* (`logit(Austin) − logit(Dallas)`) and an abstract *semantic concept* (`Capitals − States`)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wx2XiXVjmS8l" + }, + "outputs": [], + "source": [ + "# Define the prompt, shared attribution parameters, and the three reference tokens (`▁Austin`, `▁Dallas`, `▁Texas`). 
\n", + "\n", + "prompt = \"Fact: the capital of the state containing Dallas is\"\n", + "token_x, token_y = \"▁Austin\", \"▁Dallas\"\n", + "\n", + "# Shared attribution kwargs (apply to all runs)\n", + "# Note: max_n_logits / desired_logit_prob only apply to salient-logit mode\n", + "attr_kwargs = dict(\n", + " batch_size=256,\n", + " max_feature_nodes=8192,\n", + " offload=\"disk\" if IN_COLAB else \"cpu\",\n", + " verbose=True,\n", + ")\n", + "\n", + "# Resolve token ids for key tokens\n", + "tokenizer = model.tokenizer\n", + "idx_x = tokenizer.encode(token_x, add_special_tokens=False)[-1]\n", + "idx_y = tokenizer.encode(token_y, add_special_tokens=False)[-1]\n", + "idx_texas = tokenizer.encode(\"▁Texas\", add_special_tokens=False)[-1]\n", + "\n", + "# Bind the tokenizer and key tokens for display helpers\n", + "display_topk = partial(\n", + " display_topk_token_predictions,\n", + " tokenizer=tokenizer,\n", + " key_tokens=[(token_x, idx_x), (token_y, idx_y), (\"▁Texas\", idx_texas)],\n", + ")\n", + "\n", + "# Show baseline token probabilities\n", + "input_ids = model.ensure_tokenized(prompt)\n", + "with torch.no_grad():\n", + " baseline_logits, _ = model.get_activations(input_ids)\n", + "\n", + "key_ids = [idx_x, idx_y, idx_texas]\n", + "key_labels = [token_x, token_y, \"▁Texas\"]\n", + "display_token_probs(baseline_logits, key_ids, key_labels, title=\"Baseline probabilities\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RUn1YKnUmS8l" + }, + "source": [ + "### Automatic Target Selection — Salient Logits (`None`)\n", + "\n", + "When `attribution_targets` is `None` (the default), `AttributionTargets` auto-selects the most probable next tokens until `desired_logit_prob` cumulative probability is reached (capped at `max_n_logits`). This is the standard mode used by `attribute_demo.ipynb`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2tLE4FzdmS8m" + }, + "outputs": [], + "source": [ + "graph_salient = attribute(\n", + " prompt=prompt, model=model,\n", + " max_n_logits=10, desired_logit_prob=0.95,\n", + " **attr_kwargs,\n", + ")\n", + "print(f\"Salient-logits graph: {len(graph_salient.logit_targets)} targets, \"\n", + " f\"{graph_salient.active_features.shape[0]} active features\")\n", + "\n", + "# Free CUDA memory before next run\n", + "cleanup_cuda()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "w3cdLLfJmS8m" + }, + "source": [ + "### Token-String Targets — `Sequence[str]`\n", + "\n", + "Pass a list of token strings (e.g., `[\"▁Austin\", \"▁Dallas\"]`) to focus attribution on exactly those logits. Internally, each string is tokenized and its softmax probability and unembedding vector are computed automatically — you only need to supply the surface forms." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Vh8HPtimmS8m" + }, + "outputs": [], + "source": [ + "graph_str = attribute(\n", + " prompt=prompt, model=model,\n", + " attribution_targets=[token_x, token_y],\n", + " **attr_kwargs,\n", + ")\n", + "print(f\"String-targets graph: {len(graph_str.logit_targets)} targets, \"\n", + " f\"{graph_str.active_features.shape[0]} active features\")\n", + "\n", + "# Free CUDA memory before next run\n", + "cleanup_cuda()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Token-ID Targets — `torch.Tensor`\n", + "\n", + "Pass a tensor of vocabulary token IDs to attribute from specific indices. This is the pre-tokenized equivalent of the string-target mode above — internally, the same probabilities and unembedding vectors are computed. Use this mode when you already have token IDs (e.g., from a prior tokenization step) and want to skip the string→ID lookup." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use the same token IDs as the string-target example above\n", + "tensor_targets = torch.tensor([idx_x, idx_y])\n", + "\n", + "graph_tensor = attribute(\n", + " prompt=prompt, model=model,\n", + " attribution_targets=tensor_targets,\n", + " **attr_kwargs,\n", + ")\n", + "print(f\"Tensor-targets graph: {len(graph_tensor.logit_targets)} targets, \"\n", + " f\"{graph_tensor.active_features.shape[0]} active features\")\n", + "\n", + "# Free CUDA memory before next run\n", + "cleanup_cuda()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Custom Attribution Targets\n", + "\n", + "Beyond the basic modes above, `AttributionTargets` also accepts a `Sequence[TargetSpec]` — fully specified custom targets that let you attribute toward **arbitrary directions** in the residual stream. This makes a vast experimental surface more accessible but we'll explore a couple examples in this tutorial:\n", + "\n", + "- **Logit Difference Target** — encodes the direction `logit(Austin) − logit(Dallas)`, surfacing features that drive the model to prefer one token *over* another rather than boosting either in isolation.\n", + "- **Semantic Concept Target** — encodes an abstract *Capitals − States* direction built from multiple (capital, state) pairs via vector rejection, isolating the *capital-of* relation from shared geography.\n", + "\n", + "See the expandable section below if you want a more detailed look at `CustomTarget` definition before we proceed with the examples below.\n", + "\n", + "
\n", + "TargetSpec / CustomTarget — field reference\n", + "\n", + "The `attribution_targets` argument to `attribute()` accepts a `Sequence[TargetSpec]` for fully custom residual-stream directions. Two convenience types are involved:\n", + "\n", + "**`CustomTarget(token_str, prob, vec)`** is a `NamedTuple` with three fields:\n", + "\n", + "| Field | Type | Description |\n", + "|---|---|---|\n", + "| `token_str` | `str` | Human-readable label for this target (e.g. `\"logit(Austin)−logit(Dallas)\"`) |\n", + "| `prob` | `float` | Scalar weight — typically the softmax probability of the token, or \\|p(x)−p(y)\\| for a contrast direction |\n", + "| `vec` | `Tensor (d_model,)` | The direction in residual-stream space to attribute toward |\n", + "\n", + "**`TargetSpec`** is a type alias for `CustomTarget | tuple[str, float, torch.Tensor]`. Either form is accepted — a raw 3-tuple is coerced to a `CustomTarget` namedtuple automatically before processing.\n", + "\n", + "**Example — raw tuple (coerced automatically):**\n", + "\n", + "```python\n", + "raw: TargetSpec = (\"my-direction\", 0.05, some_tensor) # plain 3-tuple → TargetSpec\n", + "graph = attribute(prompt=prompt, model=model, attribution_targets=[raw])\n", + "```\n", + "\n", + "**Example — explicit `CustomTarget` namedtuple:**\n", + "\n", + "```python\n", + "from circuit_tracer.attribution.targets import CustomTarget\n", + "\n", + "target = CustomTarget(\n", + " token_str=\"logit(Austin)−logit(Dallas)\",\n", + " prob=abs(p_austin - p_dallas), # scalar weight\n", + " vec=unembed_austin - unembed_dallas, # shape: (d_model,)\n", + ")\n", + "graph = attribute(prompt=prompt, model=model, attribution_targets=[target])\n", + "```\n", + "\n", + "
\n", + "\n", + "We first define two helper functions for building these custom targets, then construct and attribute from each one." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Target Builder Helpers\n", + "\n", + "We define two helper functions that each return a `CustomTarget` namedtuple.\n", + "\n", + "---\n", + "\n", + "**`build_custom_diff_target`** — *Logit-difference direction.*\n", + "\n", + "Subtracts the unembedding column of token $y$ from that of token $x$:\n", + "\n", + "$$\\mathbf{d} = \\mathbf{u}_{x} - \\mathbf{u}_{y}$$\n", + "\n", + "and weights the target by the absolute softmax-probability difference $|p(x) - p(y)|$. Attributing toward $\\mathbf{d}$ surfaces features that drive the model to prefer token $x$ *over* token $y$ — a narrower signal than boosting $x$ in isolation.\n", + "\n", + "---\n", + "\n", + "**`build_semantic_concept_target`** — *Abstract concept direction via vector rejection.*\n", + "\n", + "Given paired token groups $A$ (e.g. capital cities) and $B$ (e.g. states), the function strips the \"state\" component from each \"capital\" unembedding vector via orthogonal projection:\n", + "\n", + "$$\\mathbf{r}_i = \\mathbf{u}_{a_i} - \\frac{\\mathbf{u}_{a_i} \\cdot \\mathbf{u}_{b_i}}{\\|\\mathbf{u}_{b_i}\\|^2}\\,\\mathbf{u}_{b_i}$$\n", + "\n", + "The final concept direction is the mean of these residuals across all pairs:\n", + "\n", + "$$\\mathbf{d}_{\\text{concept}} = \\frac{1}{n}\\sum_{i=1}^{n} \\mathbf{r}_i$$\n", + "\n", + "**Intuition:** Raw capital-city vectors (Austin, Sacramento, …) are partially explained by their shared geography with their respective states (Texas, California, …). Projecting away the state component leaves a representation of *\"capital-ness\"* that is independent of specific geography. 
Attributing toward $\\mathbf{d}_{\\text{concept}}$ reveals features the model uses to execute the abstract *capital-of* relation in this context — a strictly more targeted lens than a single logit difference or token string target." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def _get_last_position_probs(model, prompt):\n", + " \"\"\"Get softmax probabilities at the last token position.\"\"\"\n", + " input_ids = model.ensure_tokenized(prompt)\n", + " with torch.no_grad():\n", + " logits, _ = model.get_activations(input_ids)\n", + " return torch.softmax(logits.squeeze(0)[-1], dim=-1)\n", + "\n", + "\n", + "def build_custom_diff_target(model, prompt, token_x, token_y, backend):\n", + " \"\"\"Build a CustomTarget for the direction logit(token_x) − logit(token_y).\n", + "\n", + " Returns (custom_target, idx_x, idx_y).\n", + " \"\"\"\n", + " tokenizer = model.tokenizer\n", + " idx_x = tokenizer.encode(token_x, add_special_tokens=False)[-1]\n", + " idx_y = tokenizer.encode(token_y, add_special_tokens=False)[-1]\n", + "\n", + " # Extract unembed columns\n", + " vec_x, vec_y = get_unembed_vecs(model, [idx_x, idx_y], backend)\n", + " diff_vec = vec_x - vec_y\n", + "\n", + " # Weight = |p(token_x) − p(token_y)|, floored at 1e-6\n", + " probs = _get_last_position_probs(model, prompt)\n", + " diff_prob = max((probs[idx_x] - probs[idx_y]).abs().item(), 1e-6)\n", + "\n", + " custom_target = CustomTarget(\n", + " token_str=f\"logit({token_x})-logit({token_y})\",\n", + " prob=diff_prob,\n", + " vec=diff_vec,\n", + " )\n", + " return custom_target, idx_x, idx_y\n", + "\n", + "\n", + "def build_semantic_concept_target(model, prompt, group_a_tokens, group_b_tokens, label, backend):\n", + " \"\"\"Build a CustomTarget for an abstract concept direction via vector rejection.\n", + "\n", + " For each (capital, state) pair, project the capital vector onto the state\n", + " vector and subtract that projection. 
The residual is the component of\n", + " \"capital-ness\" orthogonal to its state, stripping out shared geography.\n", + "\n", + " v_residual_i = v_cap_i − proj_{v_state_i}(v_cap_i)\n", + "\n", + " The final direction is the mean of these residuals.\n", + "\n", + " Returns CustomTarget.\n", + " \"\"\"\n", + " assert len(group_a_tokens) == len(group_b_tokens), \"Groups must have equal length for paired differences\"\n", + " tokenizer = model.tokenizer\n", + " ids_a = [tokenizer.encode(t, add_special_tokens=False)[-1] for t in group_a_tokens]\n", + " ids_b = [tokenizer.encode(t, add_special_tokens=False)[-1] for t in group_b_tokens]\n", + "\n", + " vecs_a = get_unembed_vecs(model, ids_a, backend)\n", + " vecs_b = get_unembed_vecs(model, ids_b, backend)\n", + "\n", + " # Vector rejection: for each pair, remove the state-direction component\n", + " residuals = []\n", + " for va, vb in zip(vecs_a, vecs_b):\n", + " va_f, vb_f = va.float(), vb.float()\n", + " proj = (va_f @ vb_f) / (vb_f @ vb_f) * vb_f # proj_{state}(capital)\n", + " residuals.append((va_f - proj).to(va.dtype))\n", + "\n", + " direction = torch.stack(residuals).mean(0)\n", + "\n", + " # Weight = average probability of group-A tokens, floored at 1e-6\n", + " probs = _get_last_position_probs(model, prompt)\n", + " avg_prob = max(sum(probs[i].item() for i in ids_a) / len(ids_a), 1e-6)\n", + "\n", + " return CustomTarget(token_str=label, prob=avg_prob, vec=direction)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Custom Target Configuration\n", + "\n", + "Build the two custom targets and display a summary of all attribution configurations." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Build the custom diff-target: logit(Austin) − logit(Dallas)\n", + "custom_target, _, _ = build_custom_diff_target(\n", + " model, prompt, token_x, token_y, backend=backend\n", + ")\n", + "\n", + "# Build the semantic concept target: Capital Cities − States\n", + "capitals = [\"▁Austin\", \"▁Sacramento\", \"▁Olympia\", \"▁Atlanta\"]\n", + "states = [\"▁Texas\", \"▁California\", \"▁Washington\", \"▁Georgia\"]\n", + "semantic_target = build_semantic_concept_target(\n", + " model, prompt, capitals, states,\n", + " label=\"Capitals − States\", backend=backend,\n", + ")\n", + "\n", + "display_attribution_config(\n", + " token_pairs=[(token_x, idx_x), (token_y, idx_y), (\"▁Texas\", idx_texas)],\n", + " target_pairs=[(\"Logit diff\", custom_target), (\"Semantic concept\", semantic_target)],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EQuFE-eimS8m" + }, + "source": [ + "### Logit Difference Target\n", + "\n", + "Pass a `CustomTarget` (or any `TargetSpec` — a tuple of `(token_str, prob, vec)`) that encodes a contrast direction in the residual stream. Here the direction is `logit(Austin) − logit(Dallas)`, constructing an attribution graph that more narrowly surfaces features driving the selection of the *correct* answer over the surface-level attractor." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gMZ8Ee-KmS8m" + }, + "outputs": [], + "source": [ + "graph_custom = attribute(\n", + " prompt=prompt, model=model,\n", + " attribution_targets=[custom_target],\n", + " **attr_kwargs,\n", + ")\n", + "print(f\"Custom-target graph: {len(graph_custom.logit_targets)} targets, \"\n", + " f\"{graph_custom.active_features.shape[0]} active features\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Semantic Direction (Concept Target)\n", + "\n", + "Instead of a pairwise logit difference, we can attribute to an **abstract concept direction** in the residual stream. We build a `CustomTarget` via vector rejection: for each (capital, state) pair, project the capital vector onto the state vector and subtract that projection, leaving the pure 'capital-ness' component." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "graph_semantic = attribute(\n", + " prompt=prompt, model=model,\n", + " attribution_targets=[semantic_target],\n", + " **attr_kwargs,\n", + ")\n", + "print(f\"Semantic-target graph: {len(graph_semantic.logit_targets)} targets, \"\n", + " f\"{graph_semantic.active_features.shape[0]} active features\")\n", + "\n", + "# Free CUDA memory before feature comparison\n", + "cleanup_cuda()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yDGiO8jBmS8m" + }, + "source": [ + "## Compare Top Features\n", + "\n", + "Extract the top-10 features from each graph (ranked by multi-hop influence) and display them side by side. Feature indices link to their [Neuronpedia](https://www.neuronpedia.org/) dashboards. The *Custom Target* column highlights features that specifically drive the Austin-vs-Dallas logit difference — the multi-hop reasoning circuit (Dallas → Texas → capital → Austin). 
The *Concept Target* column surfaces features associated with the more general *capital-of* relation, which partially overlaps with the multi-hop chain but also includes distinct features that may reflect more abstract capital-related reasoning.\n", + "\n", + "> **Note:** The `torch.Tensor` target example is omitted from this comparison because it uses the same token IDs as the `Sequence[str]` example — the resulting graphs are identical." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "185O1Ck1mS8m" + }, + "outputs": [], + "source": [ + "top_salient, scores_salient = get_top_features(graph_salient, n=10)\n", + "top_str, scores_str = get_top_features(graph_str, n=10)\n", + "top_custom, scores_custom = get_top_features(graph_custom, n=10)\n", + "top_semantic, scores_semantic = get_top_features(graph_semantic, n=10)\n", + "\n", + "display_top_features_comparison(\n", + " {\n", + " \"Salient Logits\": top_salient,\n", + " f\"Strings [{token_x}, {token_y}]\": top_str,\n", + " f\"Custom Fn ({custom_target.token_str})\": top_custom,\n", + " f\"Semantic Concept ({semantic_target.token_str})\": top_semantic,\n", + " },\n", + " scores_sets={\n", + " \"Salient Logits\": scores_salient,\n", + " f\"Strings [{token_x}, {token_y}]\": scores_str,\n", + " f\"Custom Fn ({custom_target.token_str})\": scores_custom,\n", + " f\"Semantic Concept ({semantic_target.token_str})\": scores_semantic,\n", + " },\n", + " neuronpedia_model=\"gemma-2-2b\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Circuit Interventions\n", + "\n", + "Having identified the top features for each attribution mode example, we can now run interventions, manipulating the discovered features to bolster our credence in their hypothesized causal roles. We explore both amplification and ablation of the logit-difference and semantic concept circuits." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "98579UbGmS8m" + }, + "source": [ + "### Amplify the Austin-Dallas Logit Difference Circuit\n", + "\n", + "To confirm these custom-target features are causally meaningful, we amplify them by 10× and check that the Austin-vs-Dallas logit gap widens (i.e., the model becomes even more confident Austin is correct)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get activations for interventions\n", + "input_ids = model.ensure_tokenized(prompt)\n", + "original_logits, activations = model.get_activations(input_ids, sparse=True)\n", + "\n", + "# Baseline\n", + "display_token_probs(original_logits, key_ids, key_labels, title=\"Before amplification\")\n", + "\n", + "# Amplify top custom-target features by 10×\n", + "intervention_tuples = [\n", + " (layer, pos, feat_idx, 10.0 * activations[layer, pos, feat_idx])\n", + " for (layer, pos, feat_idx) in top_custom\n", + "]\n", + "\n", + "new_logits, _ = model.feature_intervention(input_ids, intervention_tuples)\n", + "\n", + "display_token_probs(new_logits, key_ids, key_labels, title=\"After 10× amplification\")\n", + "\n", + "orig_gap = (original_logits.squeeze(0)[-1, idx_x] - original_logits.squeeze(0)[-1, idx_y]).item()\n", + "new_gap = (new_logits.squeeze(0)[-1, idx_x] - new_logits.squeeze(0)[-1, idx_y]).item()\n", + "print(f\"\\nlogit(Austin) − logit(Dallas): {orig_gap:.4f} → {new_gap:.4f} (Δ = {new_gap - orig_gap:+.4f})\")\n", + "\n", + "display_topk(prompt, original_logits, new_logits)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Amplify the Semantic Concept Circuit\n", + "\n", + "Same amplification test for the **semantic concept** features. We compare a modest 2× boost (a gentle nudge along the concept axis) with a strong 10× boost to observe the difference in behaviour." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Baseline\n", + "display_token_probs(original_logits, key_ids, key_labels, title=\"Before amplification (semantic)\")\n", + "\n", + "orig_gap = (original_logits.squeeze(0)[-1, idx_x] - original_logits.squeeze(0)[-1, idx_y]).item()\n", + "\n", + "# --- 2× amplification (gentle nudge along the concept axis) ---\n", + "sem_amp_tuples_2 = [\n", + " (layer, pos, feat_idx, 2.0 * activations[layer, pos, feat_idx])\n", + " for (layer, pos, feat_idx) in top_semantic\n", + "]\n", + "\n", + "sem_amp_logits_2, _ = model.feature_intervention(input_ids, sem_amp_tuples_2)\n", + "\n", + "display_token_probs(sem_amp_logits_2, key_ids, key_labels, title=\"After 2× amplification (semantic)\")\n", + "\n", + "sem_gap_2 = (sem_amp_logits_2.squeeze(0)[-1, idx_x] - sem_amp_logits_2.squeeze(0)[-1, idx_y]).item()\n", + "print(f\"\\nlogit(Austin) − logit(Dallas): {orig_gap:.4f} → {sem_gap_2:.4f} (Δ = {sem_gap_2 - orig_gap:+.4f}) [2×]\")\n", + "\n", + "display_topk(prompt, original_logits, sem_amp_logits_2)\n", + "\n", + "# --- 10× amplification (strong boost) ---\n", + "sem_amp_tuples_10 = [\n", + " (layer, pos, feat_idx, 10.0 * activations[layer, pos, feat_idx])\n", + " for (layer, pos, feat_idx) in top_semantic\n", + "]\n", + "\n", + "sem_amp_logits_10, _ = model.feature_intervention(input_ids, sem_amp_tuples_10)\n", + "\n", + "display_token_probs(sem_amp_logits_10, key_ids, key_labels, title=\"After 10× amplification (semantic)\")\n", + "\n", + "sem_gap_10 = (sem_amp_logits_10.squeeze(0)[-1, idx_x] - sem_amp_logits_10.squeeze(0)[-1, idx_y]).item()\n", + "print(f\"\\nlogit(Austin) − logit(Dallas): {orig_gap:.4f} → {sem_gap_10:.4f} (Δ = {sem_gap_10 - orig_gap:+.4f}) [10×]\")\n", + "\n", + "display_topk(prompt, original_logits, sem_amp_logits_10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Ablate the Austin-Dallas Logit Difference Circuit\n", 
+ "\n", + "Now we do the opposite: zero out progressively more features important to our custom target to dampen the Austin-driving circuit. With enough of the multi-hop reasoning path suppressed, the model can no longer resolve the correct answer and reverts to nearby concepts — e.g. the intermediate state (Texas) rather than its capital." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import display, Markdown\n", + "\n", + "# Progressive ablation: zero out increasing numbers of custom-target features\n", + "probs_base = torch.softmax(original_logits.squeeze(0)[-1].float(), dim=-1)\n", + "groups = {\"baseline\": {\n", + " \"P(Austin)\": probs_base[idx_x].item(),\n", + " \"P(Dallas)\": probs_base[idx_y].item(),\n", + " \"P(Texas)\": probs_base[idx_texas].item(),\n", + "}}\n", + "logit_diffs = {\"baseline\": orig_gap}\n", + "\n", + "ablation_results = {}\n", + "for n in [10, 100]:\n", + " top_n, _ = get_top_features(graph_custom, n=n)\n", + " abl_tuples = [\n", + " (layer, pos, feat_idx, 0.0 * activations[layer, pos, feat_idx])\n", + " for (layer, pos, feat_idx) in top_n\n", + " ]\n", + " abl_logits, _ = model.feature_intervention(input_ids, abl_tuples)\n", + " probs_abl = torch.softmax(abl_logits.squeeze(0)[-1].float(), dim=-1)\n", + " gap = (abl_logits.squeeze(0)[-1, idx_x] - abl_logits.squeeze(0)[-1, idx_y]).item()\n", + " label = f\"top-{n}\"\n", + " groups[label] = {\n", + " \"P(Austin)\": probs_abl[idx_x].item(),\n", + " \"P(Dallas)\": probs_abl[idx_y].item(),\n", + " \"P(Texas)\": probs_abl[idx_texas].item(),\n", + " }\n", + " logit_diffs[label] = gap\n", + " ablation_results[n] = abl_logits\n", + "\n", + "display_ablation_chart(groups, logit_diffs=logit_diffs,\n", + " title=\"Custom-target ablation: token probabilities & logit gap\")\n", + "\n", + "# Show the full top-k comparison for the strongest ablation\n", + "strongest_n = max(ablation_results.keys())\n", + 
"display(Markdown(f\"#### Top-{strongest_n} ablation — full prediction shift\"))\n", + "display_topk(prompt, original_logits, ablation_results[strongest_n])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Ablate the Semantic Concept Circuit\n", + "\n", + "Same progressive ablation, now zeroing out features from the **semantic concept** graph. Because the concept direction captures the capital-vs-state pathway, ablation should similarly collapse the Austin signal." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import display, Markdown\n", + "\n", + "# Progressive ablation of semantic-target features\n", + "sem_groups = {\"baseline\": {\n", + " \"P(Austin)\": probs_base[idx_x].item(),\n", + " \"P(Dallas)\": probs_base[idx_y].item(),\n", + " \"P(Texas)\": probs_base[idx_texas].item(),\n", + "}}\n", + "sem_logit_diffs = {\"baseline\": orig_gap}\n", + "\n", + "sem_ablation_results = {}\n", + "for n in [10, 100]:\n", + " top_n, _ = get_top_features(graph_semantic, n=n)\n", + " abl_tuples = [\n", + " (layer, pos, feat_idx, 0.0 * activations[layer, pos, feat_idx])\n", + " for (layer, pos, feat_idx) in top_n\n", + " ]\n", + " abl_logits, _ = model.feature_intervention(input_ids, abl_tuples)\n", + " probs_abl = torch.softmax(abl_logits.squeeze(0)[-1].float(), dim=-1)\n", + " gap = (abl_logits.squeeze(0)[-1, idx_x] - abl_logits.squeeze(0)[-1, idx_y]).item()\n", + " label = f\"top-{n}\"\n", + " sem_groups[label] = {\n", + " \"P(Austin)\": probs_abl[idx_x].item(),\n", + " \"P(Dallas)\": probs_abl[idx_y].item(),\n", + " \"P(Texas)\": probs_abl[idx_texas].item(),\n", + " }\n", + " sem_logit_diffs[label] = gap\n", + " sem_ablation_results[n] = abl_logits\n", + "\n", + "display_ablation_chart(sem_groups, logit_diffs=sem_logit_diffs,\n", + " title=\"Semantic-target ablation: token probabilities & logit gap\")\n", + "\n", + "# Show the full top-k comparison for 
the strongest ablation\n", + "strongest_n = max(sem_ablation_results.keys())\n", + "display(Markdown(f\"#### Top-{strongest_n} semantic ablation — full prediction shift\"))\n", + "display_topk(prompt, original_logits, sem_ablation_results[strongest_n])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IGnU9l1zmS8m" + }, + "source": [ + "## Visualize the Semantic Concept Graph\n", + "\n", + "Save the **semantic concept** graph and serve it locally. The interactive visualization shows the circuit driving the abstract `Capitals − States` direction — the multi-hop reasoning path.\n", + "\n", + "**If running on a remote server, set up port forwarding so that port 8046 is accessible on your local machine.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "graph_dir = Path(\"attribution_targets_demo/graphs\")\n", + "graph_dir.mkdir(parents=True, exist_ok=True)\n", + "graph_path = graph_dir / \"dallas_austin_semantic_concept_graph.pt\"\n", + "graph_semantic.to_pt(graph_path)\n", + "\n", + "slug = \"dallas-austin-semantic-concept\"\n", + "graph_file_dir = \"attribution_targets_demo/graph_files\"\n", + "node_threshold, edge_threshold = 0.8, 0.98\n", + "\n", + "create_graph_files(\n", + " graph_or_path=graph_path,\n", + " slug=slug,\n", + " output_path=graph_file_dir,\n", + " node_threshold=node_threshold,\n", + " edge_threshold=edge_threshold,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GmKhWpuUmS8n" + }, + "outputs": [], + "source": [ + "from circuit_tracer.frontend.local_server import serve\n", + "\n", + "port = 8046\n", + "server = serve(data_dir=\"attribution_targets_demo/graph_files/\", port=port)\n", + "\n", + "if IN_COLAB:\n", + " from google.colab import output as colab_output # noqa\n", + " colab_output.serve_kernel_port_as_iframe(\n", + " port, path=\"/index.html\", height=\"800px\", 
cache_in_notebook=True\n", + " )\n", + "else:\n", + " from IPython.display import IFrame\n", + " print(f\"Open your graph at: http://localhost:{port}/index.html\")\n", + " display(IFrame(src=f\"http://localhost:{port}/index.html\", width=\"100%\", height=\"800px\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "uCo4FSQwqcBl" + }, + "outputs": [], + "source": [ + "# server.stop()" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "ct_dev (3.13.11)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.11" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/demos/img/attribution_targets/attribution_targets_banner.png b/demos/img/attribution_targets/attribution_targets_banner.png new file mode 100644 index 00000000..c256b3f4 Binary files /dev/null and b/demos/img/attribution_targets/attribution_targets_banner.png differ diff --git a/tests/conftest.py b/tests/conftest.py index 67cef3cf..d7088353 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,9 @@ import pytest import torch +# Check for 32GB+ VRAM once at module load time +has_32gb = torch.cuda.is_available() and torch.cuda.get_device_properties(0).total_memory > 32e9 + @pytest.fixture(autouse=True) def set_torch_seed() -> None: diff --git a/tests/test_attribution_targets.py b/tests/test_attribution_targets.py new file mode 100644 index 00000000..4e5855e3 --- /dev/null +++ b/tests/test_attribution_targets.py @@ -0,0 +1,902 @@ +"""Tests for AttributionTargets class.""" + +import gc +from collections.abc import Sequence +from typing import cast + +import torch +import pytest + +from circuit_tracer import Graph, 
ReplacementModel +from circuit_tracer.attribution.attribute import attribute +from circuit_tracer.attribution.targets import AttributionTargets, CustomTarget, LogitTarget + + +class MockTokenizer: + """Mock tokenizer for testing. + + This tokenizer supports bijective encode/decode for strings of the form + ``"tok_"`` so that roundtrip consistency tests work correctly. + """ + + vocab_size = 100 # Define vocab size for testing + + def encode(self, text, add_special_tokens=False): + # Simple mock: return token indices within valid range (0-99) + if not text: + return [] + # Support roundtrip: if text is "tok_", decode back to N + if text.startswith("tok_"): + try: + idx = int(text[4:]) + if 0 <= idx < self.vocab_size: + return [idx] + except ValueError: + pass + # Fallback: use hash to generate consistent indices within range + return [hash(text) % self.vocab_size] + + def decode(self, token_id): + """Decode a single token ID to a string.""" + if isinstance(token_id, int): + return f"tok_{token_id}" + return str(token_id) + + +@pytest.fixture +def mock_data(): + """Create mock logits and unembedding projection.""" + vocab_size = 100 + d_model = 64 + + # Create reproducible random data + torch.manual_seed(42) + logits = torch.randn(vocab_size) + unembed_proj = torch.randn(d_model, vocab_size) + tokenizer = MockTokenizer() + + return logits, unembed_proj, tokenizer + + +# === Sequence[str] mode tests === + + +def test_attribution_targets_str_list(mock_data): + """Test AttributionTargets with Sequence[str] input (list).""" + logits, unembed_proj, tokenizer = mock_data + targets = AttributionTargets( + attribution_targets=["hello", "world", "test"], + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + ) + + assert len(targets) == 3 + assert all(isinstance(t, LogitTarget) for t in targets.logit_targets) + assert targets.logit_probabilities.shape == (3,) + assert targets.logit_vectors.shape == (3, 64) + # All should have real vocab indices + assert 
all(t.vocab_idx < tokenizer.vocab_size for t in targets.logit_targets) + # token_ids should work (all real indices) + token_ids = targets.token_ids + assert token_ids.shape == (3,) + # tokens property should return decoded strings + assert all(len(t) > 0 for t in targets.tokens) + + +# === Sequence[TargetSpec] mode tests === + + +@pytest.mark.parametrize( + "target_tuples,expected_keys", + [ + ( + [ + ("token1", 0.4, torch.randn(64)), + ("token2", 0.3, torch.randn(64)), + ("token3", 0.3, torch.randn(64)), + ], + ["token1", "token2", "token3"], + ), + ], + ids=["all_tuples"], +) +def test_attribution_targets_tuple_list(mock_data, target_tuples, expected_keys): + """Test AttributionTargets with Sequence[tuple[str, float, Tensor]] input.""" + logits, unembed_proj, tokenizer = mock_data + targets = AttributionTargets( + attribution_targets=target_tuples, + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + ) + + assert len(targets) == len(expected_keys) + assert all(isinstance(t, LogitTarget) for t in targets.logit_targets) + # Tuple targets get virtual indices + assert all(t.vocab_idx >= tokenizer.vocab_size for t in targets.logit_targets) + # Check token_str matches expected keys + for i, expected_key in enumerate(expected_keys): + assert targets.logit_targets[i].token_str == expected_key + assert torch.allclose(targets.logit_probabilities, torch.tensor([0.4, 0.3, 0.3])) + + +def test_attribution_targets_custom_target_namedtuple(mock_data): + """Test AttributionTargets with Sequence[CustomTarget] input.""" + logits, unembed_proj, tokenizer = mock_data + + custom_targets = [ + CustomTarget(token_str="target_a", prob=0.6, vec=torch.randn(64)), + CustomTarget(token_str="target_b", prob=0.4, vec=torch.randn(64)), + ] + targets = AttributionTargets( + attribution_targets=custom_targets, + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + ) + + assert len(targets) == 2 + assert all(isinstance(t, LogitTarget) for t in 
targets.logit_targets) + # CustomTarget targets get virtual indices + assert all(t.vocab_idx >= tokenizer.vocab_size for t in targets.logit_targets) + assert targets.logit_targets[0].token_str == "target_a" + assert targets.logit_targets[1].token_str == "target_b" + assert torch.allclose(targets.logit_probabilities, torch.tensor([0.6, 0.4])) + + +# === Auto modes (None and Tensor) === + + +@pytest.mark.parametrize( + "attribution_targets,max_n_logits,desired_prob,test_id", + [ + (None, 5, 0.8, "salient"), + (torch.tensor([5, 10, 15]), None, None, "specific_indices"), + ], + ids=["salient", "specific_indices"], +) +def test_attribution_targets_auto_modes( + mock_data, attribution_targets, max_n_logits, desired_prob, test_id +): + """Test AttributionTargets with automatic modes (None and Tensor).""" + logits, unembed_proj, tokenizer = mock_data + + kwargs = {} + if max_n_logits is not None: + kwargs["max_n_logits"] = max_n_logits + if desired_prob is not None: + kwargs["desired_logit_prob"] = desired_prob + + targets = AttributionTargets( + attribution_targets=attribution_targets, + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + **kwargs, + ) + + assert isinstance(targets.logit_targets, list) + assert all(isinstance(t, LogitTarget) for t in targets.logit_targets) + assert all(t.vocab_idx < tokenizer.vocab_size for t in targets.logit_targets) + + if test_id == "salient": + assert len(targets) <= max_n_logits + assert len(targets) >= 1 + prob_sum = targets.logit_probabilities.sum().item() + assert prob_sum >= desired_prob or len(targets) == max_n_logits + elif test_id == "specific_indices": + assert [t.vocab_idx for t in targets.logit_targets] == [5, 10, 15] + assert targets.logit_probabilities.shape == (3,) + assert targets.logit_vectors.shape == (3, 64) + + +# === Error handling === + + +@pytest.mark.parametrize( + "targets_input,error_type,error_match", + [ + ( + [("token", 0.5)], # Only 2 elements, should be 3 + ValueError, + "exactly 3 
elements", + ), + ( + [(5, 0.5, torch.randn(64))], # int instead of str + TypeError, + "Custom target token_str must be str", + ), + ( + [], # Empty list + ValueError, + "cannot be empty", + ), + ( + [42], # int in list (no longer supported) + TypeError, + "Sequence elements must be str or TargetSpec", + ), + ( + torch.tensor([5, 105, 10]), # Tensor with out of range + ValueError, + "Token indices must be in range", + ), + ], + ids=[ + "invalid_tuple_length", + "invalid_tuple_token_type", + "empty_list", + "int_in_list_rejected", + "tensor_out_of_range", + ], +) +def test_attribution_targets_errors(mock_data, targets_input, error_type, error_match): + """Test AttributionTargets error handling.""" + logits, unembed_proj, tokenizer = mock_data + + with pytest.raises(error_type, match=error_match): + AttributionTargets( + attribution_targets=targets_input, # type: ignore + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + ) + + +# === Consistency tests === + + +def test_attribution_targets_str_list_consistency(mock_data): + """Test that the same string list inputs produce consistent results.""" + logits, unembed_proj, tokenizer = mock_data + + targets1 = AttributionTargets( + attribution_targets=["hello", "world"], + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + ) + targets2 = AttributionTargets( + attribution_targets=["hello", "world"], + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + ) + assert targets1.logit_targets == targets2.logit_targets + assert torch.equal(targets1.logit_probabilities, targets2.logit_probabilities) + assert torch.equal(targets1.logit_vectors, targets2.logit_vectors) + + +def test_attribution_targets_none_vs_str_list_consistency(mock_data): + """Test that None (auto-select) and equivalent Sequence[str] produce same results. + + Runs with None to auto-select salient logits, then constructs equivalent + Sequence[str] from the auto-selected token strings and verifies consistency. 
+ """ + logits, unembed_proj, tokenizer = mock_data + + # Auto-select + targets_auto = AttributionTargets( + attribution_targets=None, + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + max_n_logits=5, + desired_logit_prob=0.8, + ) + + # Reconstruct using the auto-selected token strings + auto_token_strs = targets_auto.tokens + targets_explicit = AttributionTargets( + attribution_targets=auto_token_strs, + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + ) + + # Same logit targets + assert targets_auto.logit_targets == targets_explicit.logit_targets + # Same probabilities + assert torch.allclose(targets_auto.logit_probabilities, targets_explicit.logit_probabilities) + # Same vectors + assert torch.allclose(targets_auto.logit_vectors, targets_explicit.logit_vectors) + + +def test_attribution_targets_none_vs_tuple_list_consistency(mock_data): + """Test that None and equivalent Sequence[TargetSpec] produce same results. + + Auto-selects, then constructs equivalent Sequence[TargetSpec] with the same + probabilities and vectors, and verifies consistency. 
+ """ + logits, unembed_proj, tokenizer = mock_data + + # Auto-select + targets_auto = AttributionTargets( + attribution_targets=None, + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + max_n_logits=3, + desired_logit_prob=0.5, + ) + + # Reconstruct as tuple list with same probs and vecs + tuple_targets = [ + (tok, prob.item(), vec) + for tok, prob, vec in zip( + targets_auto.tokens, + targets_auto.logit_probabilities, + targets_auto.logit_vectors, + ) + ] + targets_tuple = AttributionTargets( + attribution_targets=tuple_targets, + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + ) + + # Same probabilities + assert torch.allclose(targets_auto.logit_probabilities, targets_tuple.logit_probabilities) + # Same vectors + assert torch.allclose(targets_auto.logit_vectors, targets_tuple.logit_vectors) + # Same token strings + assert targets_auto.tokens == targets_tuple.tokens + + +# === Tuple (non-list Sequence) input tests === + + +def test_attribution_targets_tuple_of_strs(mock_data): + """Test AttributionTargets accepts tuple[str, ...] as Sequence[str] input.""" + logits, unembed_proj, tokenizer = mock_data + targets = AttributionTargets( + attribution_targets=("hello", "world", "test"), + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + ) + + assert len(targets) == 3 + assert all(isinstance(t, LogitTarget) for t in targets.logit_targets) + assert targets.logit_probabilities.shape == (3,) + assert targets.logit_vectors.shape == (3, 64) + assert all(t.vocab_idx < tokenizer.vocab_size for t in targets.logit_targets) + + +def test_attribution_targets_tuple_of_target_specs(mock_data): + """Test AttributionTargets accepts tuple[TargetSpec, ...] 
as Sequence[TargetSpec] input.""" + logits, unembed_proj, tokenizer = mock_data + ct1 = CustomTarget(token_str="alpha", prob=0.6, vec=torch.randn(64)) + ct2 = CustomTarget(token_str="beta", prob=0.4, vec=torch.randn(64)) + targets = AttributionTargets( + attribution_targets=(ct1, ct2), + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + ) + + assert len(targets) == 2 + assert all(isinstance(t, LogitTarget) for t in targets.logit_targets) + assert targets.logit_targets[0].token_str == "alpha" + assert targets.logit_targets[1].token_str == "beta" + assert torch.allclose(targets.logit_probabilities, torch.tensor([0.6, 0.4])) + + +# === Property and utility tests === + + +def test_attribution_targets_tokens_property(mock_data): + """Test tokens property returns correct strings for tuple targets.""" + logits, unembed_proj, tokenizer = mock_data + + targets = AttributionTargets( + attribution_targets=[ + ("arbitrary", 0.5, torch.randn(64)), + ("custom_func", 0.3, torch.randn(64)), + ], + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + ) + + tokens = targets.tokens + assert tokens == ["arbitrary", "custom_func"] + + +def test_attribution_targets_virtual_token_ids(mock_data): + """Test token_ids property for tuple targets (virtual indices).""" + logits, unembed_proj, tokenizer = mock_data + vocab_size = tokenizer.vocab_size + + targets = AttributionTargets( + attribution_targets=[ + ("t1", 0.3, torch.randn(64)), + ("t2", 0.4, torch.randn(64)), + ("t3", 0.3, torch.randn(64)), + ], + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + ) + + expected = [vocab_size + 0, vocab_size + 1, vocab_size + 2] + assert targets.token_ids.tolist() == expected + + +def test_attribution_targets_token_ids_real(mock_data): + """Test token_ids property for real vocab indices (str list and tensor).""" + logits, unembed_proj, tokenizer = mock_data + + # Tensor input + targets = AttributionTargets( + attribution_targets=torch.tensor([5, 
10, 15]), + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + ) + token_ids = targets.token_ids + assert torch.equal(token_ids, torch.tensor([5, 10, 15], dtype=torch.long)) + + +@pytest.mark.parametrize( + "test_method,expected_value", + [ + ("to_device", "cpu"), + ("repr", "AttributionTargets"), + ("len", 3), + ], + ids=["to_device", "repr", "len"], +) +def test_attribution_targets_utility_methods(mock_data, test_method, expected_value): + """Test utility methods: to(), __repr__(), and __len__().""" + logits, unembed_proj, tokenizer = mock_data + + targets = AttributionTargets( + attribution_targets=["a", "b", "c"], + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + ) + + if test_method == "to_device": + targets_cpu = targets.to("cpu") + assert isinstance(targets_cpu, AttributionTargets) + assert targets_cpu.logit_probabilities.device.type == expected_value + assert targets_cpu.logit_vectors.device.type == expected_value + assert targets_cpu.tokenizer is tokenizer + elif test_method == "repr": + repr_str = repr(targets) + assert "AttributionTargets" in repr_str + assert "n=3" in repr_str + elif test_method == "len": + assert len(targets) == expected_value + + +# === Multi-token encoding tests === + + +def test_attribution_targets_multi_token_error(mock_data): + """Test that multi-token strings raise a ValueError.""" + logits, unembed_proj, tokenizer = mock_data + + # Mock tokenizer to return multi-token encoding for a specific string + original_encode = tokenizer.encode + + def multi_token_encode(text, add_special_tokens=False): + if text == "multi_token_string": + return [10, 20, 30] # Three tokens + return original_encode(text, add_special_tokens) + + tokenizer.encode = multi_token_encode + + with pytest.raises(ValueError, match="encoded to 3 tokens"): + AttributionTargets( + attribution_targets=["multi_token_string"], + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + ) + + # Restore original encode + 
tokenizer.encode = original_encode + + +# === Type validation === + + +def test_attribution_targets_tuple_invalid_prob_type(mock_data): + """Test that invalid prob type raises TypeError.""" + logits, unembed_proj, tokenizer = mock_data + + with pytest.raises(TypeError, match="Custom target prob must be int or float"): + from circuit_tracer.attribution.targets import TargetSpec + + invalid_targets = cast( + Sequence[TargetSpec], + [ + ( + "token1", + "0.5", + torch.randn(64), + ), # String instead of float - intentionally invalid + ], + ) + AttributionTargets( + attribution_targets=invalid_targets, + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + ) + + +def test_attribution_targets_tuple_invalid_vec_type(mock_data): + """Test that invalid vec type raises TypeError.""" + logits, unembed_proj, tokenizer = mock_data + + with pytest.raises(TypeError, match="Custom target vec must be torch.Tensor"): + from circuit_tracer.attribution.targets import TargetSpec + + invalid_targets = cast( + Sequence[TargetSpec], + [ + ("token1", 0.5, [1.0, 2.0, 3.0]), # List instead of Tensor - intentionally invalid + ], + ) + AttributionTargets( + attribution_targets=invalid_targets, + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + ) + + +def test_attribution_targets_tuple_valid_int_prob(mock_data): + """Test that int probability is accepted (not just float).""" + logits, unembed_proj, tokenizer = mock_data + + targets = AttributionTargets( + attribution_targets=[ + ("token1", 1, torch.randn(64)), # Int probability + ], + logits=logits, + unembed_proj=unembed_proj, + tokenizer=tokenizer, + ) + + assert len(targets) == 1 + assert targets.logit_probabilities[0].item() == 1.0 + + +# ============================================================================= +# Integration tests: custom target correctness & format consistency +# ============================================================================= + +# === Shared helpers for integration 
tests === + + +def _get_top_features(graph: Graph, n: int = 10) -> list[tuple[int, int, int]]: + """Extract the top-N feature nodes from the graph based on attribution scores. + + Returns list of (layer, pos, feature_idx) tuples. + """ + error_node_offset = graph.active_features.shape[0] + _, first_order_indices = torch.topk(graph.adjacency_matrix[-1, :error_node_offset], n) + top_features = [tuple(x) for x in graph.active_features[first_order_indices].tolist()] + return top_features + + +def _get_unembed_weights(model, backend: str): + """Helper to get unembedding weights in a backend-agnostic way.""" + if backend == "transformerlens": + return model.unembed.W_U # (d_model, d_vocab) + else: + return model.unembed_weight # (d_vocab, d_model) for NNSight + + +def _build_custom_diff_target( + model, prompt: str, token_x: str, token_y: str, backend: str +) -> tuple[CustomTarget, int, int]: + """Build a CustomTarget representing logit(x) - logit(y) from the model's unembed matrix. + + Returns: + Tuple of (custom_target, idx_x, idx_y) where idx_x and idx_y are + the token indices for x and y respectively. 
+ """ + tokenizer = model.tokenizer + idx_x = tokenizer.encode(token_x, add_special_tokens=False)[-1] + idx_y = tokenizer.encode(token_y, add_special_tokens=False)[-1] + + input_ids = model.ensure_tokenized(prompt) + with torch.no_grad(): + logits, _ = model.get_activations(input_ids) + last_logits = logits.squeeze(0)[-1] # (d_vocab,) + + # Auto-detect matrix orientation by matching against vocabulary size + d_vocab = tokenizer.vocab_size + unembed = _get_unembed_weights(model, backend) + if unembed.shape[0] == d_vocab: + vec_x = unembed[idx_x] # (d_model,) + vec_y = unembed[idx_y] # (d_model,) + else: + # Shape is (d_model, d_vocab) – second axis is vocabulary (e.g., TransformerLens) + vec_x = unembed[:, idx_x] # (d_model,) + vec_y = unembed[:, idx_y] # (d_model,) + + diff_vec = vec_x - vec_y + # Use the absolute difference in softmax probabilities as weight + probs = torch.softmax(last_logits, dim=-1) + diff_prob = (probs[idx_x] - probs[idx_y]).abs().item() + if diff_prob < 1e-6: + diff_prob = 0.5 # fallback weight if probs are nearly equal + + custom_target = CustomTarget( + token_str=f"logit({token_x})-logit({token_y})", + prob=diff_prob, + vec=diff_vec, + ) + return custom_target, idx_x, idx_y + + +def _cfg_backend(backend: str): + """Return (model, n_layers_range, unembed_proj) for the given backend.""" + if backend == "transformerlens": + model = ReplacementModel.from_pretrained("google/gemma-2-2b", "gemma") + n_layers_range = range(model.cfg.n_layers) # type: ignore + unembed_proj = model.unembed.W_U + else: + model = ReplacementModel.from_pretrained("google/gemma-2-2b", "gemma", backend="nnsight") + n_layers_range = range(model.config.num_hidden_layers) # type: ignore + unembed_proj = model.unembed_weight + return model, n_layers_range, unembed_proj + + +def _run_attribution_format_consistency(backend: str): + """Backend-agnostic logic for attribution target format consistency test. 
+ + Runs attribution with None (auto-select), then constructs equivalent Sequence[str] + and Sequence[CustomTarget] from the auto-selected targets and verifies consistency. + """ + prompt = "Entropy spares no entity" + + model, _, unembed_proj = _cfg_backend(backend) + + # Run with None (auto-select salient logits) + graph_none = attribute(prompt, model, attribution_targets=None, max_n_logits=5, batch_size=256) + + # Extract the auto-selected token strings and their internal data + auto_token_strs = [t.token_str for t in graph_none.logit_targets] + + # Run with Sequence[str] using the same token strings + graph_str = attribute(prompt, model, attribution_targets=auto_token_strs, batch_size=256) + + # Run with Sequence[CustomTarget] using the same tokens, probs, and vectors + # Reconstruct the unembed vectors for each auto-selected token + input_ids = model.ensure_tokenized(prompt) + with torch.no_grad(): + logits, _ = model.get_activations(input_ids) + last_logits = logits.squeeze(0)[-1] + + # Build the same AttributionTargets that _from_salient would produce to extract the exact vectors + assert isinstance(unembed_proj, torch.Tensor) + auto_targets_obj = AttributionTargets( + attribution_targets=None, + logits=last_logits, + unembed_proj=unembed_proj, + tokenizer=model.tokenizer, + max_n_logits=5, + desired_logit_prob=0.8, + ) + + custom_targets = [ + CustomTarget(token_str=tok, prob=prob.item(), vec=vec) + for tok, prob, vec in zip( + auto_targets_obj.tokens, + auto_targets_obj.logit_probabilities, + auto_targets_obj.logit_vectors, + ) + ] + + graph_tuple = attribute(prompt, model, attribution_targets=custom_targets, batch_size=256) + + # Verify consistency between None and Sequence[str] + # Same number of targets + assert len(graph_none.logit_targets) == len(graph_str.logit_targets), ( + f"None ({len(graph_none.logit_targets)}) vs str ({len(graph_str.logit_targets)}) " + f"target count mismatch" + ) + + # Same token strings + none_tokens = [t.token_str for t in 
graph_none.logit_targets] + str_tokens = [t.token_str for t in graph_str.logit_targets] + assert none_tokens == str_tokens, f"Token strings differ: {none_tokens} vs {str_tokens}" + + # Same probabilities (within tolerance) + assert torch.allclose( + graph_none.logit_probabilities, + graph_str.logit_probabilities, + atol=1e-6, + ), "Probabilities differ between None and Sequence[str] modes" + + # Same adjacency matrix (within tolerance) + assert torch.allclose( + graph_none.adjacency_matrix, + graph_str.adjacency_matrix, + atol=1e-5, + rtol=1e-4, + ), "Adjacency matrices differ between None and Sequence[str] modes" + + # Verify consistency between None and Sequence[CustomTarget] + assert len(graph_none.logit_targets) == len(graph_tuple.logit_targets), ( + f"None ({len(graph_none.logit_targets)}) vs tuple ({len(graph_tuple.logit_targets)}) " + f"target count mismatch" + ) + + # Token strings should match + tuple_tokens = [t.token_str for t in graph_tuple.logit_targets] + assert none_tokens == tuple_tokens, f"Token strings differ: {none_tokens} vs {tuple_tokens}" + + # Probabilities should match + assert torch.allclose( + graph_none.logit_probabilities, + graph_tuple.logit_probabilities.to(graph_none.logit_probabilities.device), + atol=1e-6, + ), "Probabilities differ between None and Sequence[CustomTarget] modes" + + # Adjacency matrices should match (tuple targets use the same unembed vecs) + assert torch.allclose( + graph_none.adjacency_matrix, + graph_tuple.adjacency_matrix.to(graph_none.adjacency_matrix.device), + atol=1e-5, + rtol=1e-4, + ), "Adjacency matrices differ between None and Sequence[CustomTarget] modes" + + +def _run_custom_target_correctness( + backend: str, + n_samples: int = 20, + act_atol: float = 5e-4, + act_rtol: float = 1e-5, + logit_atol: float = 1e-4, + logit_rtol: float = 1e-3, +): + """Verify custom target direction feature attribution driven interventions produce expected activation/logit changes + + For a ``logit(x) − logit(y)`` custom 
direction, randomly samples features, doubles each feature's pre-activation + value (under constrained/frozen-layer conditions), and checks that both the activation changes and the custom + logit-difference change match the adjacency matrix predictions within acceptable tolerances. + + * **Activation changes** match ``adjacency_matrix[:n_features, node]`` within act_atol/act_rtol. + * **Custom logit-difference change** matches the adjacency logit-node prediction within logit_atol/logit_rtol. + + We use the same linear-regime conditions as our other attribution validation tests, e.g. ``verify_feature_edges``: + + * ``constrained_layers=range(n_layers)`` — freezes all layer norms, MLPs, and attention, preventing non-linear + propagation. + * ``apply_activation_function=False`` — operates on pre-activation values. + * ``model.zero_softcap()`` — removes the final logit softcap. + * Intervention = doubling the pre-activation (delta = old_activation). Because the adjacency column already encodes + the full effect of the feature at its current activation level, doubling adds exactly one copy of the predicted + effect. 
+ """ + prompt = "The capital of the state containing Dallas is" + token_x, token_y = "▁Austin", "▁Dallas" + + model, n_layers_range, _ = _cfg_backend(backend) + custom_target, idx_x, idx_y = _build_custom_diff_target( + model, prompt, token_x, token_y, backend + ) + + graph = attribute(prompt, model, attribution_targets=[custom_target], batch_size=256) + + device = next(model.parameters()).device + adjacency_matrix = graph.adjacency_matrix.to(device) + active_features = graph.active_features.to(device) + n_features = active_features.size(0) + n_logits = len(graph.logit_targets) + + # --- baseline (pre-activation, unsoftcapped) --- + with model.zero_softcap(): + logits, activation_cache = model.get_activations( + graph.input_tokens, apply_activation_function=False + ) + logits = logits.squeeze(0) + + relevant_activations = activation_cache[ + active_features[:, 0], active_features[:, 1], active_features[:, 2] + ] + baseline_logit_diff = logits[-1, idx_x] - logits[-1, idx_y] + + # --- per-feature exact checks --- + random_order = torch.randperm(n_features) + chosen_nodes = random_order[: min(n_samples, n_features)] + + for chosen_node in chosen_nodes: + layer, pos, feat_idx = active_features[chosen_node].tolist() + old_activation = activation_cache[layer, pos, feat_idx] + new_activation = old_activation * 2 + + expected_effects = adjacency_matrix[:, chosen_node] + expected_act_diff = expected_effects[:n_features] + expected_logit_diff = expected_effects[-n_logits:] # (1,) for single target + + with model.zero_softcap(): + new_logits, new_act_cache = model.feature_intervention( + graph.input_tokens, + [(layer, pos, feat_idx, new_activation)], + constrained_layers=n_layers_range, + apply_activation_function=False, + ) + new_logits = new_logits.squeeze(0) + + # -- activation check -- + assert new_act_cache is not None + new_relevant_activations = new_act_cache[ + active_features[:, 0], active_features[:, 1], active_features[:, 2] + ] + assert torch.allclose( + 
new_relevant_activations, + relevant_activations + expected_act_diff, + atol=act_atol, + rtol=act_rtol, + ), ( + f"Activation mismatch for feature ({layer}, {pos}, {feat_idx}): " + f"max diff = {(new_relevant_activations - relevant_activations - expected_act_diff).abs().max():.6e}" + ) + + # -- logit-difference check -- + new_logit_diff = new_logits[-1, idx_x] - new_logits[-1, idx_y] + actual_logit_change = (new_logit_diff - baseline_logit_diff).unsqueeze(0) + assert torch.allclose( + actual_logit_change, + expected_logit_diff, + atol=logit_atol, + rtol=logit_rtol, + ), ( + f"Logit-diff mismatch for feature ({layer}, {pos}, {feat_idx}): " + f"predicted={expected_logit_diff.item():.6e}, " + f"actual={actual_logit_change.item():.6e}" + ) + + +@pytest.fixture(autouse=False) +def cleanup_cuda(): + yield + gc.collect() + torch.cuda.empty_cache() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("backend", ["transformerlens", "nnsight"]) +def test_custom_target_correctness(cleanup_cuda, backend): + """Verify custom target direction feature attribution driven interventions produce expected activation/logit changes + + For a ``logit(x) − logit(y)`` custom direction, randomly samples features, doubles each feature's pre-activation + value (under constrained/frozen-layer conditions), and checks that both the activation changes and the custom + logit-difference change match the adjacency matrix predictions within acceptable tolerances. + + Args: + cleanup_cuda: Fixture for CUDA cleanup after test + backend: Model backend to test ("transformerlens" or "nnsight") + """ + _run_custom_target_correctness(backend) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("backend", ["transformerlens", "nnsight"]) +def test_attribution_format_consistency(cleanup_cuda, backend): + """Verify None, Sequence[str], and Sequence[CustomTarget] produce consistent results. 
+ + Runs attribution with None (auto-select), then with equivalent Sequence[str] and + Sequence[CustomTarget] targets, and verifies the graphs are consistent. + + Args: + cleanup_cuda: Fixture for CUDA cleanup after test + backend: Model backend to test ("transformerlens" or "nnsight") + """ + _run_attribution_format_consistency(backend) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_attributions_gemma.py b/tests/test_attributions_gemma.py index f117bf21..d7981b9a 100644 --- a/tests/test_attributions_gemma.py +++ b/tests/test_attributions_gemma.py @@ -36,7 +36,7 @@ def verify_token_and_error_edges( s = graph.input_tokens adjacency_matrix = graph.adjacency_matrix.to(get_default_device()) active_features = graph.active_features.to(get_default_device()) - logit_tokens = graph.logit_tokens.to(get_default_device()) + logit_tokens = graph.logit_token_ids total_active_features = active_features.size(0) pos_start = 1 # ignore first token (BOS) @@ -131,7 +131,7 @@ def verify_feature_edges( s = graph.input_tokens adjacency_matrix = graph.adjacency_matrix.to(get_default_device()) active_features = graph.active_features.to(get_default_device()) - logit_tokens = graph.logit_tokens.to(get_default_device()) + logit_tokens = graph.logit_token_ids total_active_features = active_features.size(0) logits, activation_cache = model.get_activations(s, apply_activation_function=False) @@ -154,9 +154,10 @@ def verify_intervention( ) # type:ignore new_logits = new_logits.squeeze(0) - new_relevant_activations = new_activation_cache[ # type:ignore + assert new_activation_cache is not None + new_relevant_activations = new_activation_cache[ active_features[:, 0], active_features[:, 1], active_features[:, 2] - ] # type:ignore + ] new_relevant_logits = new_logits[-1, logit_tokens] new_demeaned_relevant_logits = new_relevant_logits - new_logits[-1].mean() @@ -396,7 +397,7 @@ def test_gemma_2_2b(): s = "The National Digital Analytics Group (ND" model = 
ReplacementModel.from_pretrained("google/gemma-2-2b", "gemma") assert isinstance(model, TransformerLensReplacementModel) - graph = attribute(s, model) + graph = attribute(s, model, batch_size=256) print("Changing logit softcap to 0, as the logits will otherwise be off.") with model.zero_softcap(): @@ -409,7 +410,7 @@ def test_gemma_2_2b_clt(): s = "The National Digital Analytics Group (ND" model = ReplacementModel.from_pretrained("google/gemma-2-2b", "mntss/clt-gemma-2-2b-426k") assert isinstance(model, TransformerLensReplacementModel) - graph = attribute(s, model) + graph = attribute(s, model, batch_size=256) print("Changing logit softcap to 0, as the logits will otherwise be off.") with model.zero_softcap(): diff --git a/tests/test_attributions_gemma3_nnsight.py b/tests/test_attributions_gemma3_nnsight.py index a3d9b9c4..d5c6f1f1 100644 --- a/tests/test_attributions_gemma3_nnsight.py +++ b/tests/test_attributions_gemma3_nnsight.py @@ -14,6 +14,7 @@ from circuit_tracer.transcoder.activation_functions import JumpReLU from circuit_tracer.transcoder.cross_layer_transcoder import CrossLayerTranscoder from circuit_tracer.replacement_model.replacement_model_nnsight import NNSightReplacementModel +from tests.conftest import has_32gb gemma_3_config_dict = { "_sliding_window_pattern": 6, @@ -265,7 +266,7 @@ def verify_feature_edges( model: NNSightReplacementModel, graph: Graph, n_samples: int = 100, - act_atol=5e-4, + act_atol=1e-3, # dummy transcoder gemma3 tests need slightly higher tolerance act_rtol=1e-5, logit_atol=1e-5, logit_rtol=1e-3, @@ -484,7 +485,7 @@ def test_gemma3_with_dummy_transcoders(): s = "The National Digital Analytics Group (ND" model = load_gemma3_with_dummy_transcoders() model.to(torch.float32) # type:ignore - graph = attribute(s, model) + graph = attribute(s, model, batch_size=256) assert isinstance(model, NNSightReplacementModel) @@ -498,7 +499,7 @@ def test_gemma3_with_dummy_clt(): s = "The National Digital Analytics Group (ND" model = 
load_gemma3_with_dummy_clt() model.to(torch.float32) # type:ignore - graph = attribute(s, model) + graph = attribute(s, model, batch_size=256) assert isinstance(model, NNSightReplacementModel) @@ -525,7 +526,7 @@ def test_gemma_3_1b(): verify_feature_edges(model, graph) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not has_32gb, reason="Requires >=32GB VRAM") def test_gemma_3_1b_it(): s = "user\nThe National Digital Analytics Group (ND" model = ReplacementModel.from_pretrained( @@ -543,7 +544,7 @@ def test_gemma_3_1b_it(): verify_feature_edges(model, graph) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not has_32gb, reason="Requires >=32GB VRAM") def test_gemma_3_1b_clt(): s = "The National Digital Analytics Group (ND" model = ReplacementModel.from_pretrained( @@ -561,7 +562,7 @@ def test_gemma_3_1b_clt(): verify_feature_edges(model, graph) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not has_32gb, reason="Requires >=32GB VRAM") def test_gemma_3_4b(): s = "The National Digital Analytics Group (ND" model = ReplacementModel.from_pretrained( diff --git a/tests/test_attributions_gemma_nnsight.py b/tests/test_attributions_gemma_nnsight.py index 386e669c..7474b14e 100644 --- a/tests/test_attributions_gemma_nnsight.py +++ b/tests/test_attributions_gemma_nnsight.py @@ -341,7 +341,7 @@ def test_gemma_2_2b(): model = ReplacementModel.from_pretrained("google/gemma-2-2b", "gemma", backend="nnsight") assert isinstance(model, NNSightReplacementModel) - graph = attribute(s, model) + graph = attribute(s, model, batch_size=256) print("Changing logit softcap to 0, as the logits will otherwise be off.") with model.zero_softcap(): diff --git a/tests/test_attributions_llama_nnsight.py b/tests/test_attributions_llama_nnsight.py index 52e2de39..295be5b4 100644 --- a/tests/test_attributions_llama_nnsight.py +++ 
b/tests/test_attributions_llama_nnsight.py @@ -11,6 +11,7 @@ from circuit_tracer.replacement_model.replacement_model_nnsight import NNSightReplacementModel from circuit_tracer.transcoder import SingleLayerTranscoder, TranscoderSet from circuit_tracer.transcoder.activation_functions import TopK +from tests.conftest import has_32gb sys.path.append(os.path.dirname(__file__)) from test_attributions_gemma_nnsight import verify_feature_edges, verify_token_and_error_edges @@ -149,7 +150,7 @@ def test_large_llama_model(): tokenizer_class.all_special_ids = original_all_special_ids # type:ignore -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not has_32gb, reason="Requires >=32GB VRAM") def test_llama_3_2_1b(): s = "The National Digital Analytics Group (ND" model = ReplacementModel.from_pretrained( diff --git a/tests/test_freeze_points.py b/tests/test_freeze_points.py index 373c9650..cfb518df 100644 --- a/tests/test_freeze_points.py +++ b/tests/test_freeze_points.py @@ -31,7 +31,8 @@ def cleanup_cuda(): # ("google/gemma-3-1b-pt", "mwhanna/gemma-scope-2-1b-pt/clt/width_262k_l0_medium_affine"), # This requires lazy loading ( "google/gemma-3-4b-pt", - "mwhanna/gemma-scope-2-4b-pt/transcoder_all/width_262k_l0_small_affine", + # we use width_16k here instead of 262k to avoid large download not used elsewhere in the test suite + "mwhanna/gemma-scope-2-4b-pt/transcoder_all/width_16k_l0_small_affine", ), ] diff --git a/tests/test_graph.py b/tests/test_graph.py index 5776850a..da3ec8fb 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -5,6 +5,7 @@ import torch from transformer_lens import HookedTransformerConfig +from circuit_tracer.attribution.targets import LogitTarget from circuit_tracer.graph import Graph, compute_edge_influence, compute_node_influence from circuit_tracer.utils import get_default_device @@ -123,7 +124,7 @@ def test_small_graph(): active_features=torch.tensor([1, 2, 3, 4, 5]), 
adjacency_matrix=adjacency_matrix, cfg=cfg, - logit_tokens=torch.tensor([0]), + logit_targets=[LogitTarget(token_str="tok_0", vocab_idx=0)], logit_probabilities=torch.tensor([1.0]), selected_features=torch.tensor([1, 2, 3, 4, 5]), activation_values=torch.tensor([1, 2, 3, 4, 5]) * 2, @@ -141,3 +142,217 @@ def test_small_graph(): edge_influence_on_logits = compute_edge_influence(pruned_adjacency_matrix, logit_weights) assert torch.allclose(edge_influence_on_logits, post_pruning_edge_matrix) + + +def test_graph_with_tensor_logit_targets(): + """Test that Graph accepts LogitTarget list format and from_pt handles legacy tensor format.""" + cfg = HookedTransformerConfig.from_dict( + { + "n_layers": 2, + "d_model": 8, + "n_ctx": 32, + "d_head": 4, + "n_heads": 2, + "d_mlp": 16, + "act_fn": "gelu", + "d_vocab": 50257, # GPT-2 vocab size + "model_name": "test-model", + "device": get_default_device(), + } + ) + + adjacency_matrix = torch.zeros([10, 10]) + adjacency_matrix[9, 5] = 1.0 + + # Test with LogitTarget list using empty token strings (simulates legacy conversion) + graph_tensor = Graph( + input_string="test", + input_tokens=torch.tensor([1, 2, 3]), + active_features=torch.tensor([[0, 0, 5]]), + adjacency_matrix=adjacency_matrix, + cfg=cfg, + logit_targets=[ + LogitTarget(token_str="", vocab_idx=262), + LogitTarget(token_str="", vocab_idx=290), + LogitTarget(token_str="", vocab_idx=314), + ], + logit_probabilities=torch.tensor([0.5, 0.3, 0.2]), + selected_features=torch.tensor([0]), + activation_values=torch.tensor([1.5]), + ) + + # Verify LogitTarget list format + assert len(graph_tensor.logit_targets) == 3 + assert graph_tensor.logit_targets[0].vocab_idx == 262 + assert graph_tensor.logit_targets[1].vocab_idx == 290 + assert graph_tensor.logit_targets[2].vocab_idx == 314 + # Token strings are empty when constructed from legacy tensor + assert graph_tensor.logit_targets[0].token_str == "" + assert graph_tensor.logit_targets[1].token_str == "" + assert 
graph_tensor.logit_targets[2].token_str == "" + + # Verify properties work + assert graph_tensor.logit_token_ids.tolist() == [262, 290, 314] + assert torch.equal(graph_tensor.logit_token_ids, torch.tensor([262, 290, 314])) + + # Test with LogitTarget list format (current) + graph_list = Graph( + input_string="test", + input_tokens=torch.tensor([1, 2, 3]), + active_features=torch.tensor([[0, 0, 5]]), + adjacency_matrix=adjacency_matrix, + cfg=cfg, + logit_targets=[ + LogitTarget(token_str=" the", vocab_idx=262), + LogitTarget(token_str=" a", vocab_idx=290), + LogitTarget(token_str=" and", vocab_idx=314), + ], + logit_probabilities=torch.tensor([0.5, 0.3, 0.2]), + selected_features=torch.tensor([0]), + activation_values=torch.tensor([1.5]), + ) + + # Verify both formats produce same logit_token_ids + assert torch.equal(graph_tensor.logit_token_ids, graph_list.logit_token_ids) + assert graph_tensor.vocab_size == graph_list.vocab_size + + +@pytest.mark.parametrize( + "logit_targets_input,expected_token_strs", + [ + pytest.param( + [ + LogitTarget(token_str="", vocab_idx=262), + LogitTarget(token_str="", vocab_idx=290), + LogitTarget(token_str="", vocab_idx=314), + ], + ["", "", ""], + id="empty_token_str_format", + ), + pytest.param( + [ + LogitTarget(token_str=" the", vocab_idx=262), + LogitTarget(token_str=" a", vocab_idx=290), + LogitTarget(token_str=" and", vocab_idx=314), + ], + [" the", " a", " and"], + id="logit_target_format", + ), + ], +) +def test_graph_serialization_with_logit_targets(logit_targets_input, expected_token_strs): + """Test that Graph serialization works with both tensor and LogitTarget formats.""" + import tempfile + import os + + cfg = HookedTransformerConfig.from_dict( + { + "n_layers": 2, + "d_model": 8, + "n_ctx": 32, + "d_head": 4, + "n_heads": 2, + "d_mlp": 16, + "act_fn": "gelu", + "d_vocab": 50257, + "model_name": "test-model", + "device": get_default_device(), + } + ) + + adjacency_matrix = torch.zeros([10, 10]) + adjacency_matrix[9, 
5] = 1.0 + + # Create graph with parameterized format + original_graph = Graph( + input_string="test", + input_tokens=torch.tensor([1, 2, 3]), + active_features=torch.tensor([[0, 0, 5]]), + adjacency_matrix=adjacency_matrix, + cfg=cfg, + logit_targets=logit_targets_input, + logit_probabilities=torch.tensor([0.5, 0.3, 0.2]), + selected_features=torch.tensor([0]), + activation_values=torch.tensor([1.5]), + vocab_size=50257, + ) + + # Save and load + with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp: + tmp_path = tmp.name + + try: + original_graph.to_pt(tmp_path) + loaded_graph = Graph.from_pt(tmp_path) + + # Verify loaded graph has correct data + assert loaded_graph.logit_token_ids.tolist() == [262, 290, 314] + assert loaded_graph.vocab_size == 50257 + assert torch.equal(loaded_graph.logit_token_ids, torch.tensor([262, 290, 314])) + assert torch.equal(loaded_graph.logit_probabilities, torch.tensor([0.5, 0.3, 0.2])) + + # Verify LogitTarget objects were preserved with expected token strings + assert len(loaded_graph.logit_targets) == 3 + assert all(isinstance(lt, LogitTarget) for lt in loaded_graph.logit_targets) + assert loaded_graph.logit_targets[0].token_str == expected_token_strs[0] + assert loaded_graph.logit_targets[1].token_str == expected_token_strs[1] + assert loaded_graph.logit_targets[2].token_str == expected_token_strs[2] + + finally: + if os.path.exists(tmp_path): + os.unlink(tmp_path) + + +def test_graph_from_pt_legacy_tensor_format(): + """Test that Graph.from_pt correctly handles legacy serialized graphs with tensor logit_targets.""" + import tempfile + import os + + cfg = HookedTransformerConfig.from_dict( + { + "n_layers": 2, + "d_model": 8, + "n_ctx": 32, + "d_head": 4, + "n_heads": 2, + "d_mlp": 16, + "act_fn": "gelu", + "d_vocab": 50257, + "model_name": "test-model", + "device": get_default_device(), + } + ) + + # Simulate a legacy .pt file with tensor logit_targets + legacy_data = { + "input_string": "test", + 
"adjacency_matrix": torch.zeros([10, 10]), + "cfg": cfg, + "active_features": torch.tensor([[0, 0, 5]]), + "logit_targets": torch.tensor([262, 290, 314]), # Legacy tensor format + "logit_probabilities": torch.tensor([0.5, 0.3, 0.2]), + "vocab_size": 50257, + "input_tokens": torch.tensor([1, 2, 3]), + "selected_features": torch.tensor([0]), + "activation_values": torch.tensor([1.5]), + "scan": None, + } + + with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp: + tmp_path = tmp.name + + try: + torch.save(legacy_data, tmp_path) + loaded_graph = Graph.from_pt(tmp_path) + + # Verify from_pt converted tensor to LogitTarget list + assert len(loaded_graph.logit_targets) == 3 + assert all(isinstance(lt, LogitTarget) for lt in loaded_graph.logit_targets) + assert loaded_graph.logit_targets[0].vocab_idx == 262 + assert loaded_graph.logit_targets[1].vocab_idx == 290 + assert loaded_graph.logit_targets[2].vocab_idx == 314 + assert loaded_graph.logit_targets[0].token_str == "" + assert loaded_graph.logit_token_ids.tolist() == [262, 290, 314] + finally: + if os.path.exists(tmp_path): + os.unlink(tmp_path) diff --git a/tests/test_offload.py b/tests/test_offload.py index 58220a47..cbfb4bba 100644 --- a/tests/test_offload.py +++ b/tests/test_offload.py @@ -12,6 +12,7 @@ from circuit_tracer.replacement_model.replacement_model_nnsight import ( NNSightReplacementModel, ) +from tests.conftest import has_32gb @pytest.fixture(autouse=True) @@ -48,7 +49,7 @@ def test_offload_tl(): assert param.device.type == original_device.type -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not has_32gb, reason="Requires >=32GB VRAM") def test_offload_nnsight(): s = "The National Digital Analytics Group (ND" model = ReplacementModel.from_pretrained("google/gemma-2-2b", "gemma", backend="nnsight") @@ -75,7 +76,7 @@ def test_offload_nnsight(): assert param.device.type == original_device.type -@pytest.mark.skipif(not 
torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not has_32gb, reason="Requires >=32GB VRAM") def test_offload_nnsight_gemma_3(): s = "The National Digital Analytics Group (ND" model_name = "google/gemma-3-4b-pt" diff --git a/tests/test_transformerlens_nnsight_same_gemma.py b/tests/test_transformerlens_nnsight_same_gemma.py index 3094e8c5..6e40eb07 100644 --- a/tests/test_transformerlens_nnsight_same_gemma.py +++ b/tests/test_transformerlens_nnsight_same_gemma.py @@ -8,6 +8,10 @@ from circuit_tracer.attribution.attribute_transformerlens import ( attribute as attribute_transformerlens, ) +from tests.conftest import has_32gb + +# Mark all tests in this module as requiring 32GB+ VRAM +pytestmark = [pytest.mark.skipif(not has_32gb, reason="Requires >=32GB VRAM")] @pytest.fixture(autouse=True) diff --git a/tests/test_transformerlens_nnsight_same_gemma_clts.py b/tests/test_transformerlens_nnsight_same_gemma_clts.py index 40b8820a..06f288b5 100644 --- a/tests/test_transformerlens_nnsight_same_gemma_clts.py +++ b/tests/test_transformerlens_nnsight_same_gemma_clts.py @@ -8,6 +8,10 @@ from circuit_tracer.attribution.attribute_transformerlens import ( attribute as attribute_transformerlens, ) +from tests.conftest import has_32gb + +# Mark all tests in this module as requiring 32GB+ VRAM +pytestmark = [pytest.mark.skipif(not has_32gb, reason="Requires >=32GB VRAM")] @pytest.fixture(autouse=True) diff --git a/tests/test_transformerlens_nnsight_same_llama.py b/tests/test_transformerlens_nnsight_same_llama.py index f3fd2133..a2afeb92 100644 --- a/tests/test_transformerlens_nnsight_same_llama.py +++ b/tests/test_transformerlens_nnsight_same_llama.py @@ -8,6 +8,10 @@ from circuit_tracer.attribution.attribute_transformerlens import ( attribute as attribute_transformerlens, ) +from tests.conftest import has_32gb + +# Mark all tests in this module as requiring 32GB+ VRAM +pytestmark = [pytest.mark.skipif(not has_32gb, reason="Requires >=32GB VRAM")] 
@pytest.fixture(autouse=True) diff --git a/tests/test_transformerlens_nnsight_same_llama_clts.py b/tests/test_transformerlens_nnsight_same_llama_clts.py index 13ceef42..6b669c16 100644 --- a/tests/test_transformerlens_nnsight_same_llama_clts.py +++ b/tests/test_transformerlens_nnsight_same_llama_clts.py @@ -5,6 +5,10 @@ from circuit_tracer.replacement_model import ReplacementModel from circuit_tracer.attribution.attribute import attribute +from tests.conftest import has_32gb + +# Mark all tests in this module as requiring 32GB+ VRAM +pytestmark = [pytest.mark.skipif(not has_32gb, reason="Requires >=32GB VRAM")] @pytest.fixture(autouse=True) diff --git a/tests/test_tutorial_notebook_backends.py b/tests/test_tutorial_notebook_backends.py index bfc6d3f8..4148c5c9 100644 --- a/tests/test_tutorial_notebook_backends.py +++ b/tests/test_tutorial_notebook_backends.py @@ -1,4 +1,5 @@ import gc +from contextlib import contextmanager import pytest import torch @@ -8,6 +9,85 @@ from circuit_tracer.attribution.attribute_transformerlens import ( attribute as attribute_transformerlens, ) +from circuit_tracer.attribution.targets import CustomTarget +from circuit_tracer.graph import compute_node_influence +from circuit_tracer.utils.demo_utils import get_unembed_vecs +from tests.conftest import has_32gb + +# decorator used to gate individual tests on available VRAM +skip32gb = pytest.mark.skipif(not has_32gb, reason="Requires >=32GB VRAM") + + +def _move_replacement_model(model, device): + """Move a ReplacementModel (and its transcoders) to *device*, updating internal refs. + + Works for both NNSight and TransformerLens backends. 
+ """ + device = torch.device(device) if isinstance(device, str) else device + + # Move model parameters + model.to(device) + + # Move transcoders — NNSight wraps them in an Envoy so .to() only takes device + try: + model.transcoders.to(device, torch.float32) + except TypeError: + model.transcoders.to(device) + + # Update stale tensor references left on the NNSight model instance. + # `.to()` replaces Parameter tensors inside the module but external refs + # (e.g. embed_weight, unembed_weight) still point at the old device. + for attr in ("embed_weight", "unembed_weight"): + t = getattr(model, attr, None) + if t is not None and t.device != device: + setattr(model, attr, t.to(device)) + + # Update backend-specific device tracking + if hasattr(model, "cfg") and hasattr(model.cfg, "device"): + # TransformerLens backend + model.cfg.device = device + + +@contextmanager +def clean_cuda(model, min_bytes: int = 1 << 20): + """Move *model* to CUDA; on exit automatically free large transient CUDA tensors. + + Snapshots data_ptrs of all large CUDA tensors after the model moves to CUDA + (capturing model weights as 'known'). On exit, any new large CUDA tensor not + in the snapshot has its storage replaced via ``set_(torch.empty(0))``, freeing + VRAM even while Python references remain alive. Then ``gc.collect()`` + + ``empty_cache()`` flush remaining allocations before the model moves back to CPU. + Callers do not need explicit ``del`` statements for large GPU-resident objects. 
+ """ + _move_replacement_model(model, "cuda") + + def _is_large_dense_cuda(t: object) -> bool: + return ( + isinstance(t, torch.Tensor) + and t.is_cuda + and t.layout == torch.strided + and t.nbytes >= min_bytes + ) + + known_ptrs: set[int] = {obj.data_ptr() for obj in gc.get_objects() if _is_large_dense_cuda(obj)} + try: + yield + finally: + freed_ptrs: set[int] = set() + for obj in gc.get_objects(): + if ( + _is_large_dense_cuda(obj) + and obj.data_ptr() not in known_ptrs + and obj.data_ptr() not in freed_ptrs + ): + freed_ptrs.add(obj.data_ptr()) + try: + obj.set_(torch.empty(0)) + except Exception: + pass + gc.collect() + torch.cuda.empty_cache() + _move_replacement_model(model, "cpu") @pytest.fixture(autouse=True) @@ -27,6 +107,27 @@ def models(): return model_nnsight, model_tl +@pytest.fixture(scope="module") +def models_cpu(): + """Load both models on CPU for memory-constrained sequential backend tests. + + Tests using this fixture should wrap each backend run in ``clean_cuda`` + to move the active model to CUDA and restore it to CPU when done, + automatically freeing transient GPU-resident objects between backend phases. 
+ """ + model_nnsight = ReplacementModel.from_pretrained( + "google/gemma-2-2b", + "gemma", + backend="nnsight", + dtype=torch.float32, + device=torch.device("cpu"), + ) + model_tl = ReplacementModel.from_pretrained( + "google/gemma-2-2b", "gemma", dtype=torch.float32, device=torch.device("cpu") + ) + return model_nnsight, model_tl + + @pytest.fixture def dallas_supernode_features(): """Features from Dallas-Austin circuit supernodes.""" @@ -122,6 +223,7 @@ def small_big_prompts(): } +@skip32gb @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_dallas_austin_activations(models, dallas_austin_prompt): """Test get_activations consistency for Dallas-Austin prompt.""" @@ -141,6 +243,7 @@ def test_dallas_austin_activations(models, dallas_austin_prompt): ) +@skip32gb @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_dallas_austin_attribution(models, dallas_austin_prompt): """Test attribution consistency for Dallas-Austin prompt.""" @@ -167,6 +270,7 @@ def test_dallas_austin_attribution(models, dallas_austin_prompt): ) +@skip32gb @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_dallas_intervention_say_capital_ablation( models, dallas_austin_prompt, dallas_supernode_features @@ -204,6 +308,7 @@ def test_dallas_intervention_say_capital_ablation( ) +@skip32gb @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_dallas_intervention_capital_ablation( models, dallas_austin_prompt, dallas_supernode_features @@ -240,6 +345,7 @@ def test_dallas_intervention_capital_ablation( ) +@skip32gb @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_dallas_intervention_texas_ablation( models, dallas_austin_prompt, dallas_supernode_features @@ -276,6 +382,7 @@ def test_dallas_intervention_texas_ablation( ) +@skip32gb @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def 
test_dallas_intervention_state_ablation( models, dallas_austin_prompt, dallas_supernode_features @@ -312,6 +419,7 @@ def test_dallas_intervention_state_ablation( ) +@skip32gb @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_dallas_intervention_replace_texas_with_california( models, dallas_austin_prompt, dallas_supernode_features, oakland_supernode_features @@ -361,6 +469,7 @@ def test_dallas_intervention_replace_texas_with_california( ) +@skip32gb @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_dallas_intervention_replace_texas_with_china( models, dallas_austin_prompt, dallas_supernode_features, shanghai_supernode_features @@ -409,6 +518,7 @@ def test_dallas_intervention_replace_texas_with_china( ) +@skip32gb @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_dallas_intervention_replace_texas_with_bc( models, dallas_austin_prompt, dallas_supernode_features, vancouver_supernode_features @@ -456,6 +566,7 @@ def test_dallas_intervention_replace_texas_with_bc( ) +@skip32gb @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_oakland_sacramento_activations(models, oakland_sacramento_prompt): """Test get_activations consistency for Oakland-Sacramento prompt.""" @@ -475,6 +586,7 @@ def test_oakland_sacramento_activations(models, oakland_sacramento_prompt): ) +@skip32gb @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_oakland_sacramento_attribution(models, oakland_sacramento_prompt): """Test attribution consistency for Oakland-Sacramento prompt.""" @@ -501,6 +613,7 @@ def test_oakland_sacramento_attribution(models, oakland_sacramento_prompt): ) +@skip32gb @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_multilingual_english_activations(models, small_big_prompts): """Test get_activations consistency for English opposite prompt.""" @@ -523,6 
+636,7 @@ def test_multilingual_english_activations(models, small_big_prompts): ) +@skip32gb @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_multilingual_french_activations(models, small_big_prompts): """Test get_activations consistency for French opposite prompt.""" @@ -543,6 +657,7 @@ def test_multilingual_french_activations(models, small_big_prompts): ) +@skip32gb @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_multilingual_chinese_activations(models, small_big_prompts): """Test get_activations consistency for Chinese opposite prompt.""" @@ -563,6 +678,7 @@ def test_multilingual_chinese_activations(models, small_big_prompts): ) +@skip32gb @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_multilingual_french_attribution(models, small_big_prompts): """Test attribution consistency for French opposite prompt.""" @@ -590,6 +706,7 @@ def test_multilingual_french_attribution(models, small_big_prompts): ) +@skip32gb @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_multilingual_french_ablation(models, small_big_prompts, multilingual_supernode_features): """Test ablating French language features (-2x).""" @@ -625,6 +742,7 @@ def test_multilingual_french_ablation(models, small_big_prompts, multilingual_su ) +@skip32gb @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_multilingual_french_to_chinese(models, small_big_prompts, multilingual_supernode_features): """Test replacing French with Chinese (French -2x, Chinese +2x).""" @@ -672,6 +790,7 @@ def test_multilingual_french_to_chinese(models, small_big_prompts, multilingual_ ) +@skip32gb @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_multilingual_replace_small_with_big( models, small_big_prompts, multilingual_supernode_features @@ -720,6 +839,7 @@ def 
test_multilingual_replace_small_with_big( @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@skip32gb def test_setup_attribution_consistency(models, dallas_austin_prompt): """Test that attribution contexts are consistent between backends.""" model_nnsight, model_tl = models @@ -743,6 +863,472 @@ def test_setup_attribution_consistency(models, dallas_austin_prompt): ) +def _build_demo_custom_target(model, prompt, token_x, token_y, backend): + """Build a CustomTarget for logit(token_x) − logit(token_y). + + Backend-agnostic helper matching the attribution_targets_demo pattern. + Uses ``get_unembed_vecs`` from ``demo_utils`` for unembedding extraction. + """ + tokenizer = model.tokenizer + idx_x = tokenizer.encode(token_x, add_special_tokens=False)[-1] + idx_y = tokenizer.encode(token_y, add_special_tokens=False)[-1] + + input_ids = model.ensure_tokenized(prompt) + with torch.no_grad(): + logits, _ = model.get_activations(input_ids) + last_logits = logits.squeeze(0)[-1] + + vec_x, vec_y = get_unembed_vecs(model, [idx_x, idx_y], backend) + diff_vec = vec_x - vec_y + probs = torch.softmax(last_logits, dim=-1) + diff_prob = max((probs[idx_x] - probs[idx_y]).abs().item(), 1e-6) + + return ( + CustomTarget(token_str=f"logit({token_x})-logit({token_y})", prob=diff_prob, vec=diff_vec), + idx_x, + idx_y, + ) + + +def _build_demo_semantic_target(model, prompt, group_a_tokens, group_b_tokens, label, backend): + """Build a CustomTarget for an abstract concept direction via vector rejection. + + For each (capital, state) pair, project the capital vector onto the state + vector and subtract that projection, leaving pure "capital-ness". + + Backend-agnostic helper matching the attribution_targets_demo pattern. 
+ """ + assert len(group_a_tokens) == len(group_b_tokens), ( + "Groups must have equal length for paired differences" + ) + tokenizer = model.tokenizer + ids_a = [tokenizer.encode(t, add_special_tokens=False)[-1] for t in group_a_tokens] + ids_b = [tokenizer.encode(t, add_special_tokens=False)[-1] for t in group_b_tokens] + + vecs_a = get_unembed_vecs(model, ids_a, backend) + vecs_b = get_unembed_vecs(model, ids_b, backend) + + # Vector rejection: for each pair, remove the state-direction component + residuals = [] + for va, vb in zip(vecs_a, vecs_b): + va_f, vb_f = va.float(), vb.float() + proj = (va_f @ vb_f) / (vb_f @ vb_f) * vb_f + residuals.append((va_f - proj).to(va.dtype)) + + direction = torch.stack(residuals).mean(0) + + input_ids = model.ensure_tokenized(prompt) + with torch.no_grad(): + logits, _ = model.get_activations(input_ids) + probs = torch.softmax(logits.squeeze(0)[-1], dim=-1) + avg_prob = max(sum(probs[i].item() for i in ids_a) / len(ids_a), 1e-6) + + return CustomTarget(token_str=label, prob=avg_prob, vec=direction) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_attribution_targets_string(models_cpu, dallas_austin_prompt): + """Test attribution with Sequence[str] targets consistency between TL and NNSight.""" + model_nnsight, model_tl = models_cpu + str_targets = ["▁Austin", "▁Dallas"] + + # --- NNSight backend --- + with clean_cuda(model_nnsight): + graph_nnsight = attribute_nnsight( + dallas_austin_prompt, + model_nnsight, + attribution_targets=str_targets, + verbose=False, + batch_size=256, + ) + nn_active = graph_nnsight.active_features.cpu() + nn_selected = graph_nnsight.selected_features.cpu() + nn_tokens = [t.token_str for t in graph_nnsight.logit_targets] + nn_adj = graph_nnsight.adjacency_matrix.cpu() + + # --- TL backend --- + with clean_cuda(model_tl): + graph_tl = attribute_transformerlens( + dallas_austin_prompt, + model_tl, + attribution_targets=str_targets, + verbose=False, + 
batch_size=128, + ) + tl_active = graph_tl.active_features.cpu() + tl_selected = graph_tl.selected_features.cpu() + tl_tokens = [t.token_str for t in graph_tl.logit_targets] + tl_adj = graph_tl.adjacency_matrix.cpu() + + # --- Compare CPU tensors --- + assert (nn_active == tl_active).all(), ( + "String-target active features don't match between backends" + ) + assert (nn_selected == tl_selected).all(), ( + "String-target selected features don't match between backends" + ) + assert nn_tokens == tl_tokens, f"String-target logit tokens differ: {nn_tokens} vs {tl_tokens}" + assert torch.allclose(nn_adj, tl_adj, atol=5e-4, rtol=1e-5), ( + f"String-target adjacency matrices differ by max {(nn_adj - tl_adj).abs().max()}" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_attribution_targets_tensor(models_cpu, dallas_austin_prompt): + """Test attribution with torch.Tensor targets consistency between TL and NNSight. + + Uses the same token IDs as the string-target test (pre-tokenized equivalent). 
+ """ + model_nnsight, model_tl = models_cpu + # Resolve token IDs for Austin and Dallas (same as string-target test) + tok = model_nnsight.tokenizer + idx_austin = tok.encode("▁Austin", add_special_tokens=False)[-1] + idx_dallas = tok.encode("▁Dallas", add_special_tokens=False)[-1] + tensor_targets = torch.tensor([idx_austin, idx_dallas]) + + # --- NNSight backend --- + with clean_cuda(model_nnsight): + graph_nnsight = attribute_nnsight( + dallas_austin_prompt, + model_nnsight, + attribution_targets=tensor_targets, + verbose=False, + batch_size=256, + ) + nn_active = graph_nnsight.active_features.cpu() + nn_selected = graph_nnsight.selected_features.cpu() + nn_tokens = [t.token_str for t in graph_nnsight.logit_targets] + nn_adj = graph_nnsight.adjacency_matrix.cpu() + + # --- TL backend --- + with clean_cuda(model_tl): + graph_tl = attribute_transformerlens( + dallas_austin_prompt, + model_tl, + attribution_targets=tensor_targets, + verbose=False, + batch_size=128, + ) + tl_active = graph_tl.active_features.cpu() + tl_selected = graph_tl.selected_features.cpu() + tl_tokens = [t.token_str for t in graph_tl.logit_targets] + tl_adj = graph_tl.adjacency_matrix.cpu() + + # --- Compare CPU tensors --- + assert (nn_active == tl_active).all(), ( + "Tensor-target active features don't match between backends" + ) + assert (nn_selected == tl_selected).all(), ( + "Tensor-target selected features don't match between backends" + ) + assert nn_tokens == tl_tokens, f"Tensor-target logit tokens differ: {nn_tokens} vs {tl_tokens}" + assert torch.allclose(nn_adj, tl_adj, atol=5e-4, rtol=1e-5), ( + f"Tensor-target adjacency matrices differ by max {(nn_adj - tl_adj).abs().max()}" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_attribution_targets_logit_diff(models_cpu, dallas_austin_prompt): + """Test attribution with CustomTarget consistency between TL and NNSight.""" + model_nnsight, model_tl = models_cpu + + # --- NNSight backend 
--- + with clean_cuda(model_nnsight): + custom_nnsight, _, _ = _build_demo_custom_target( + model_nnsight, dallas_austin_prompt, "▁Austin", "▁Dallas", backend="nnsight" + ) + graph_nnsight = attribute_nnsight( + dallas_austin_prompt, + model_nnsight, + attribution_targets=[custom_nnsight], + verbose=False, + batch_size=256, + ) + nn_active = graph_nnsight.active_features.cpu() + nn_selected = graph_nnsight.selected_features.cpu() + nn_tokens = [t.token_str for t in graph_nnsight.logit_targets] + nn_adj = graph_nnsight.adjacency_matrix.cpu() + + # --- TL backend --- + with clean_cuda(model_tl): + custom_tl, _, _ = _build_demo_custom_target( + model_tl, dallas_austin_prompt, "▁Austin", "▁Dallas", backend="transformerlens" + ) + graph_tl = attribute_transformerlens( + dallas_austin_prompt, + model_tl, + attribution_targets=[custom_tl], + verbose=False, + batch_size=128, + ) + tl_active = graph_tl.active_features.cpu() + tl_selected = graph_tl.selected_features.cpu() + tl_tokens = [t.token_str for t in graph_tl.logit_targets] + tl_adj = graph_tl.adjacency_matrix.cpu() + + # --- Compare CPU tensors --- + assert (nn_active == tl_active).all(), ( + "Custom-target active features don't match between backends" + ) + assert (nn_selected == tl_selected).all(), ( + "Custom-target selected features don't match between backends" + ) + assert nn_tokens == tl_tokens, f"Custom-target logit tokens differ: {nn_tokens} vs {tl_tokens}" + assert torch.allclose(nn_adj, tl_adj, atol=5e-4, rtol=1e-5), ( + f"Custom-target adjacency matrices differ by max {(nn_adj - tl_adj).abs().max()}" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_attribution_targets_logit_diff_intervention(models_cpu, dallas_austin_prompt): + """Test custom-target feature amplification consistency between TL and NNSight.""" + model_nnsight, model_tl = models_cpu + n_top = 10 + + def _get_top_features(graph, n): + n_logits = len(graph.logit_targets) + n_features = 
len(graph.selected_features) + logit_weights = torch.zeros( + graph.adjacency_matrix.shape[0], device=graph.adjacency_matrix.device + ) + logit_weights[-n_logits:] = graph.logit_probabilities + node_influence = compute_node_influence(graph.adjacency_matrix, logit_weights) + _, top_idx = torch.topk(node_influence[:n_features], min(n, n_features)) + return [tuple(graph.active_features[graph.selected_features[i]].tolist()) for i in top_idx] + + # --- NNSight backend --- + with clean_cuda(model_nnsight): + custom_nnsight, idx_x_nn, idx_y_nn = _build_demo_custom_target( + model_nnsight, dallas_austin_prompt, "▁Austin", "▁Dallas", backend="nnsight" + ) + graph_nnsight = attribute_nnsight( + dallas_austin_prompt, + model_nnsight, + attribution_targets=[custom_nnsight], + verbose=False, + batch_size=256, + ) + top_feats_nn = _get_top_features(graph_nnsight, n_top) + + input_ids_nn = model_nnsight.ensure_tokenized(dallas_austin_prompt) + orig_logits_nn, acts_nn = model_nnsight.get_activations(input_ids_nn, sparse=True) + + interv_nn = [(ly, p, f, 10.0 * acts_nn[ly, p, f]) for (ly, p, f) in top_feats_nn] + new_logits_nn, _ = model_nnsight.feature_intervention(input_ids_nn, interv_nn) + + orig_gap_nn = ( + (orig_logits_nn.squeeze(0)[-1, idx_x_nn] - orig_logits_nn.squeeze(0)[-1, idx_y_nn]) + .cpu() + .item() + ) + new_gap_nn = ( + (new_logits_nn.squeeze(0)[-1, idx_x_nn] - new_logits_nn.squeeze(0)[-1, idx_y_nn]) + .cpu() + .item() + ) + + # --- TL backend --- + with clean_cuda(model_tl): + custom_tl, idx_x_tl, idx_y_tl = _build_demo_custom_target( + model_tl, dallas_austin_prompt, "▁Austin", "▁Dallas", backend="transformerlens" + ) + graph_tl = attribute_transformerlens( + dallas_austin_prompt, + model_tl, + attribution_targets=[custom_tl], + verbose=False, + batch_size=128, + ) + top_feats_tl = _get_top_features(graph_tl, n_top) + + input_ids_tl = model_tl.ensure_tokenized(dallas_austin_prompt) + orig_logits_tl, acts_tl = model_tl.get_activations(input_ids_tl, sparse=True) + + 
interv_tl = [(ly, p, f, 10.0 * acts_tl[ly, p, f]) for (ly, p, f) in top_feats_tl] + new_logits_tl, _ = model_tl.feature_intervention(input_ids_tl, interv_tl) + + orig_gap_tl = ( + (orig_logits_tl.squeeze(0)[-1, idx_x_tl] - orig_logits_tl.squeeze(0)[-1, idx_y_tl]) + .cpu() + .item() + ) + new_gap_tl = ( + (new_logits_tl.squeeze(0)[-1, idx_x_tl] - new_logits_tl.squeeze(0)[-1, idx_y_tl]) + .cpu() + .item() + ) + + # --- Compare on CPU --- + assert new_gap_nn > orig_gap_nn, ( + f"NNSight: amplification should widen gap, got {orig_gap_nn:.4f} -> {new_gap_nn:.4f}" + ) + assert new_gap_tl > orig_gap_tl, ( + f"TL: amplification should widen gap, got {orig_gap_tl:.4f} -> {new_gap_tl:.4f}" + ) + + assert abs(new_gap_nn - new_gap_tl) < 0.5, ( + f"Post-intervention gaps differ too much: NNSight={new_gap_nn:.4f}, TL={new_gap_tl:.4f}" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_attribution_targets_semantic(models_cpu, dallas_austin_prompt): + """Test attribution with semantic concept CustomTarget consistency between TL and NNSight.""" + model_nnsight, model_tl = models_cpu + capitals = ["▁Austin", "▁Sacramento", "▁Olympia", "▁Atlanta"] + states = ["▁Texas", "▁California", "▁Washington", "▁Georgia"] + label = "Concept: Capitals − States" + + # --- NNSight backend --- + with clean_cuda(model_nnsight): + sem_nnsight = _build_demo_semantic_target( + model_nnsight, dallas_austin_prompt, capitals, states, label, backend="nnsight" + ) + graph_nnsight = attribute_nnsight( + dallas_austin_prompt, + model_nnsight, + attribution_targets=[sem_nnsight], + verbose=False, + batch_size=256, + ) + nn_active = graph_nnsight.active_features.cpu() + nn_selected = graph_nnsight.selected_features.cpu() + nn_tokens = [t.token_str for t in graph_nnsight.logit_targets] + nn_adj = graph_nnsight.adjacency_matrix.cpu() + + # --- TL backend --- + with clean_cuda(model_tl): + sem_tl = _build_demo_semantic_target( + model_tl, dallas_austin_prompt, 
capitals, states, label, backend="transformerlens" + ) + graph_tl = attribute_transformerlens( + dallas_austin_prompt, + model_tl, + attribution_targets=[sem_tl], + verbose=False, + batch_size=128, + ) + tl_active = graph_tl.active_features.cpu() + tl_selected = graph_tl.selected_features.cpu() + tl_tokens = [t.token_str for t in graph_tl.logit_targets] + tl_adj = graph_tl.adjacency_matrix.cpu() + + # --- Compare CPU tensors --- + assert (nn_active == tl_active).all(), ( + "Semantic-target active features don't match between backends" + ) + assert (nn_selected == tl_selected).all(), ( + "Semantic-target selected features don't match between backends" + ) + assert nn_tokens == tl_tokens, ( + f"Semantic-target logit tokens differ: {nn_tokens} vs {tl_tokens}" + ) + assert torch.allclose(nn_adj, tl_adj, atol=5e-4, rtol=1e-5), ( + f"Semantic-target adjacency matrices differ by max {(nn_adj - tl_adj).abs().max()}" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_attribution_targets_semantic_intervention(models_cpu, dallas_austin_prompt): + """Test semantic-target feature amplification consistency between TL and NNSight.""" + model_nnsight, model_tl = models_cpu + n_top = 10 + capitals = ["▁Austin", "▁Sacramento", "▁Olympia", "▁Atlanta"] + states = ["▁Texas", "▁California", "▁Washington", "▁Georgia"] + label = "Concept: Capitals − States" + + def _get_top_features(graph, n): + n_logits = len(graph.logit_targets) + n_features = len(graph.selected_features) + logit_weights = torch.zeros( + graph.adjacency_matrix.shape[0], device=graph.adjacency_matrix.device + ) + logit_weights[-n_logits:] = graph.logit_probabilities + node_influence = compute_node_influence(graph.adjacency_matrix, logit_weights) + _, top_idx = torch.topk(node_influence[:n_features], min(n, n_features)) + return [tuple(graph.active_features[graph.selected_features[i]].tolist()) for i in top_idx] + + # --- NNSight backend --- + with clean_cuda(model_nnsight): 
+ sem_nnsight = _build_demo_semantic_target( + model_nnsight, dallas_austin_prompt, capitals, states, label, backend="nnsight" + ) + idx_x_nn = model_nnsight.tokenizer.encode("▁Austin", add_special_tokens=False)[-1] + idx_y_nn = model_nnsight.tokenizer.encode("▁Dallas", add_special_tokens=False)[-1] + + graph_nnsight = attribute_nnsight( + dallas_austin_prompt, + model_nnsight, + attribution_targets=[sem_nnsight], + verbose=False, + batch_size=256, + ) + top_feats_nn = _get_top_features(graph_nnsight, n_top) + + input_ids_nn = model_nnsight.ensure_tokenized(dallas_austin_prompt) + orig_logits_nn, acts_nn = model_nnsight.get_activations(input_ids_nn, sparse=True) + + interv_nn = [(ly, p, f, 10.0 * acts_nn[ly, p, f]) for (ly, p, f) in top_feats_nn] + new_logits_nn, _ = model_nnsight.feature_intervention(input_ids_nn, interv_nn) + + orig_gap_nn = ( + (orig_logits_nn.squeeze(0)[-1, idx_x_nn] - orig_logits_nn.squeeze(0)[-1, idx_y_nn]) + .cpu() + .item() + ) + new_gap_nn = ( + (new_logits_nn.squeeze(0)[-1, idx_x_nn] - new_logits_nn.squeeze(0)[-1, idx_y_nn]) + .cpu() + .item() + ) + + # --- TL backend --- + with clean_cuda(model_tl): + sem_tl = _build_demo_semantic_target( + model_tl, dallas_austin_prompt, capitals, states, label, backend="transformerlens" + ) + idx_x_tl = model_tl.tokenizer.encode("▁Austin", add_special_tokens=False)[-1] + idx_y_tl = model_tl.tokenizer.encode("▁Dallas", add_special_tokens=False)[-1] + + graph_tl = attribute_transformerlens( + dallas_austin_prompt, + model_tl, + attribution_targets=[sem_tl], + verbose=False, + batch_size=128, + ) + top_feats_tl = _get_top_features(graph_tl, n_top) + + input_ids_tl = model_tl.ensure_tokenized(dallas_austin_prompt) + orig_logits_tl, acts_tl = model_tl.get_activations(input_ids_tl, sparse=True) + + interv_tl = [(ly, p, f, 10.0 * acts_tl[ly, p, f]) for (ly, p, f) in top_feats_tl] + new_logits_tl, _ = model_tl.feature_intervention(input_ids_tl, interv_tl) + + orig_gap_tl = ( + (orig_logits_tl.squeeze(0)[-1, 
idx_x_tl] - orig_logits_tl.squeeze(0)[-1, idx_y_tl]) + .cpu() + .item() + ) + new_gap_tl = ( + (new_logits_tl.squeeze(0)[-1, idx_x_tl] - new_logits_tl.squeeze(0)[-1, idx_y_tl]) + .cpu() + .item() + ) + + # --- Compare on CPU --- + assert new_gap_nn > orig_gap_nn, ( + f"NNSight: semantic amplification should widen gap, got {orig_gap_nn:.4f} -> {new_gap_nn:.4f}" + ) + assert new_gap_tl > orig_gap_tl, ( + f"TL: semantic amplification should widen gap, got {orig_gap_tl:.4f} -> {new_gap_tl:.4f}" + ) + + assert abs(new_gap_nn - new_gap_tl) < 0.5, ( + f"Semantic post-intervention gaps differ too much: NNSight={new_gap_nn:.4f}, TL={new_gap_tl:.4f}" + ) + + def run_all_tests(): """Run all tests when script is executed directly.""" print("Loading models...") @@ -899,9 +1485,31 @@ def run_all_tests(): test_setup_attribution_consistency(models_fixture, dallas_austin) print("✓ Attribution setup consistency test passed") + print("\n=== Testing Attribution Targets Demo ===") + + print("Running test_attribution_targets_string...") + test_attribution_targets_string(models_fixture, dallas_austin) + print("✓ Attribution targets string test passed") + + print("Running test_attribution_targets_logit_diff...") + test_attribution_targets_logit_diff(models_fixture, dallas_austin) + print("✓ Attribution targets logit-diff test passed") + + print("Running test_attribution_targets_logit_diff_intervention...") + test_attribution_targets_logit_diff_intervention(models_fixture, dallas_austin) + print("✓ Attribution targets logit-diff intervention test passed") + + print("Running test_attribution_targets_semantic...") + test_attribution_targets_semantic(models_fixture, dallas_austin) + print("✓ Attribution targets semantic test passed") + + print("Running test_attribution_targets_semantic_intervention...") + test_attribution_targets_semantic_intervention(models_fixture, dallas_austin) + print("✓ Attribution targets semantic intervention test passed") + print("\n" + "=" * 70) print("All tutorial 
notebook tests passed! ✓") -    print("Total tests run: 20") +    print("Total tests run: 25")     print("=" * 70)