Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
2149aa0
Release: 2025-10-21
hannamw Oct 21, 2025
b6bf15f
A proposed API enhancement (`AttributionTargets`) that encapsulates t…
speediedan Nov 6, 2025
2b52ba6
slight clarification in a couple comments based on copilot review
speediedan Nov 7, 2025
418b17f
Allow `offload_modules` to handle single module and container offload…
speediedan Dec 16, 2025
45d6eee
use a smaller non-default batch_size for test_gemma_2_2b to expand th…
speediedan Dec 16, 2025
3d8264a
initial changes to adapt original `AttributionTargets` encapsulation …
speediedan Jan 13, 2026
32f79a4
Merge remote-tracking branch 'upstream/main' into attribution-targets
speediedan Jan 13, 2026
f90a7a4
Add pytest markers for long-running and high memory tests; adjust bat…
speediedan Jan 15, 2026
af101cf
Update unembedding matrix handling to auto-detect backend-variant ori…
speediedan Jan 15, 2026
6f93ef4
Merge branch 'main' into attribution-targets
speediedan Jan 16, 2026
2f8eb2b
minor type fix, clarify current vram gating mark
speediedan Jan 17, 2026
6b76d08
adds integration tests, refactors proposed interface incorporating PR…
speediedan Feb 10, 2026
c8b4425
remove marks, add vram skipif conditions
speediedan Feb 11, 2026
f4ad1e9
revert comment hunks to submit in a separate PR
speediedan Feb 11, 2026
787cf18
adjust `test_custom_target_correctness` to adopt our standard attribu…
speediedan Feb 12, 2026
7ea7f14
very rough exploratory draft of attribution_targets_demo.ipynb to sol…
speediedan Feb 18, 2026
be7654d
updates to attribution_targets_demo.ipynb including:
speediedan Feb 19, 2026
440f097
cleanup language and formatting, prettify with banner
speediedan Feb 20, 2026
8219073
streamlined serial backend testing with `models_cpu` fixture and `cle…
speediedan Feb 21, 2026
100af4d
restructured the demo to lead with the simpler target modes and extra…
speediedan Feb 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 35 additions & 47 deletions circuit_tracer/attribution/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,45 +27,13 @@
import torch
from tqdm import tqdm

from circuit_tracer.attribution.targets import AttributionTargets
from circuit_tracer.graph import Graph
from circuit_tracer.replacement_model import ReplacementModel
from circuit_tracer.utils import get_default_device
from circuit_tracer.utils.disk_offload import offload_modules


@torch.no_grad()
def compute_salient_logits(
    logits: torch.Tensor,
    unembed_proj: torch.Tensor,
    *,
    max_n_logits: int = 10,
    desired_logit_prob: float = 0.95,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Select the smallest set of top logits whose cumulative probability
    reaches *desired_logit_prob*, capped at *max_n_logits* entries.

    Args:
        logits: ``(d_vocab,)`` vector (single position).
        unembed_proj: ``(d_model, d_vocab)`` unembedding matrix.
        max_n_logits: Hard cap *k*.
        desired_logit_prob: Cumulative probability threshold *p*.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        * logit_indices - ``(k,)`` vocabulary ids.
        * logit_probs - ``(k,)`` softmax probabilities.
        * demeaned_vecs - ``(k, d_model)`` unembedding columns, demeaned.
    """
    # Rank candidate tokens by probability mass (descending).
    candidate_probs, candidate_ids = logits.softmax(dim=-1).topk(max_n_logits)

    # Smallest prefix of the ranked candidates whose cumulative probability
    # crosses the threshold; searchsorted gives the first crossing index.
    cumulative = candidate_probs.cumsum(dim=0)
    keep = int(torch.searchsorted(cumulative, desired_logit_prob)) + 1
    kept_probs = candidate_probs[:keep]
    kept_ids = candidate_ids[:keep]

    # Demean the selected unembedding columns against the full-vocab mean.
    vocab_mean = unembed_proj.mean(dim=-1, keepdim=True)
    demeaned_cols = unembed_proj[:, kept_ids] - vocab_mean
    return kept_ids, kept_probs, demeaned_cols.T


def compute_partial_influences(edge_matrix, logit_p, row_to_node_index, max_iter=128, device=None):
"""Compute partial influences using power iteration method."""
device = device or get_default_device()
Expand Down Expand Up @@ -93,6 +61,9 @@ def attribute(
prompt: str | torch.Tensor | list[int],
model: ReplacementModel,
*,
attribution_targets: (
list[tuple[str, float, torch.Tensor] | int | str] | torch.Tensor | None
) = None,
max_n_logits: int = 10,
desired_logit_prob: float = 0.95,
batch_size: int = 512,
Expand All @@ -106,8 +77,19 @@ def attribute(
Args:
prompt: Text, token ids, or tensor - will be tokenized if str.
model: Frozen ``ReplacementModel``
max_n_logits: Max number of logit nodes.
desired_logit_prob: Keep logits until cumulative prob >= this value.
attribution_targets: Flexible attribution target specification in one of several formats:
- None: Auto-select salient logits based on probability threshold
- torch.Tensor: Tensor of token indices
- list[tuple[str, float, torch.Tensor] | int | str]: List where
each element can be:
* int or str: Token ID/string (auto-computes probability & vector,
returns tensor of indices)
* tuple[str, float, torch.Tensor]: Fully specified logit spec with
arbitrary string tokens (or functions thereof) that may not be in
vocabulary
max_n_logits: Max number of logit nodes (used when attribution_targets is None).
desired_logit_prob: Keep logits until cumulative prob >= this value
(used when attribution_targets is None).
batch_size: How many source nodes to process per backward pass.
max_feature_nodes: Max number of feature nodes to include in the graph.
offload: Method for offloading model parameters to save memory.
Expand Down Expand Up @@ -137,6 +119,7 @@ def attribute(
return _run_attribution(
model=model,
prompt=prompt,
attribution_targets=attribution_targets,
max_n_logits=max_n_logits,
desired_logit_prob=desired_logit_prob,
batch_size=batch_size,
Expand All @@ -158,6 +141,7 @@ def attribute(
def _run_attribution(
model,
prompt,
attribution_targets,
max_n_logits,
desired_logit_prob,
batch_size,
Expand Down Expand Up @@ -201,21 +185,26 @@ def _run_attribution(
n_layers, n_pos, _ = activation_matrix.shape
total_active_feats = activation_matrix._nnz()

logit_idx, logit_p, logit_vecs = compute_salient_logits(
ctx.logits[0, -1],
model.unembed.W_U,
targets = AttributionTargets(
attribution_targets=attribution_targets,
logits=ctx.logits[0, -1],
unembed_proj=model.unembed.W_U,
tokenizer=model.tokenizer,
max_n_logits=max_n_logits,
desired_logit_prob=desired_logit_prob,
)
logger.info(
f"Selected {len(logit_idx)} logits with cumulative probability {logit_p.sum().item():.4f}"
)

if attribution_targets is None:
logger.info(
f"Selected {len(targets)} logits with cumulative probability "
f"{targets.logit_probabilities.sum().item():.4f}"
)

if offload:
offload_handles += offload_modules([model.unembed, model.embed], offload)

logit_offset = len(feat_layers) + (n_layers + 1) * n_pos
n_logits = len(logit_idx)
n_logits = len(targets)
total_nodes = logit_offset + n_logits

max_feature_nodes = min(max_feature_nodes or total_active_feats, total_active_feats)
Expand All @@ -230,8 +219,8 @@ def _run_attribution(
# Phase 3: logit attribution
logger.info("Phase 3: Computing logit attributions")
phase_start = time.time()
for i in range(0, len(logit_idx), batch_size):
batch = logit_vecs[i : i + batch_size]
for i in range(0, len(targets), batch_size):
batch = targets.logit_vectors[i : i + batch_size]
rows = ctx.compute_batch(
layers=torch.full((batch.shape[0],), n_layers),
positions=torch.full((batch.shape[0],), n_pos - 1),
Expand All @@ -257,7 +246,7 @@ def _run_attribution(
pending = torch.arange(total_active_feats)
else:
influences = compute_partial_influences(
edge_matrix[:st], logit_p, row_to_node_index[:st]
edge_matrix[:st], targets.logit_probabilities, row_to_node_index[:st]
)
feature_rank = torch.argsort(influences[:total_active_feats], descending=True).cpu()
queue_size = min(update_interval * batch_size, max_feature_nodes - n_visited)
Expand Down Expand Up @@ -302,8 +291,7 @@ def _run_attribution(
graph = Graph(
input_string=model.tokenizer.decode(input_ids),
input_tokens=input_ids,
logit_tokens=logit_idx,
logit_probabilities=logit_p,
attribution_targets=targets,
active_features=activation_matrix.indices().T,
activation_values=activation_matrix.values(),
selected_features=selected_features,
Expand Down
Loading
Loading