Merged
Changes shown from 11 of 20 commits:
2149aa0  Release: 2025-10-21  (hannamw, Oct 21, 2025)
b6bf15f  A proposed API enhancement (`AttributionTargets`) that encapsulates t…  (speediedan, Nov 6, 2025)
2b52ba6  slight clarification in a couple comments based on copilot review  (speediedan, Nov 7, 2025)
418b17f  Allow `offload_modules` to handle single module and container offload…  (speediedan, Dec 16, 2025)
45d6eee  use a smaller non-default batch_size for test_gemma_2_2b to expand th…  (speediedan, Dec 16, 2025)
3d8264a  initial changes to adapt original `AttributionTargets` encapsulation …  (speediedan, Jan 13, 2026)
32f79a4  Merge remote-tracking branch 'upstream/main' into attribution-targets  (speediedan, Jan 13, 2026)
f90a7a4  Add pytest markers for long-running and high memory tests; adjust bat…  (speediedan, Jan 15, 2026)
af101cf  Update unembedding matrix handling to auto-detect backend-variant ori…  (speediedan, Jan 15, 2026)
6f93ef4  Merge branch 'main' into attribution-targets  (speediedan, Jan 16, 2026)
2f8eb2b  minor type fix, clarify current vram gating mark  (speediedan, Jan 17, 2026)
6b76d08  adds integration tests, refactors proposed interface incorporating PR…  (speediedan, Feb 10, 2026)
c8b4425  remove marks, add vram skipif conditions  (speediedan, Feb 11, 2026)
f4ad1e9  revert comment hunks to submit in a separate PR  (speediedan, Feb 11, 2026)
787cf18  adjust `test_custom_target_correctness` to adopt our standard attribu…  (speediedan, Feb 12, 2026)
7ea7f14  very rough exploratory draft of attribution_targets_demo.ipynb to sol…  (speediedan, Feb 18, 2026)
be7654d  updates to attribution_targets_demo.ipynb including:  (speediedan, Feb 19, 2026)
440f097  cleanup language and formatting, prettify with banner  (speediedan, Feb 20, 2026)
8219073  streamlined serial backend testing with `models_cpu` fixture and `cle…  (speediedan, Feb 21, 2026)
100af4d  restructured the demo to lead with the simpler target modes and extra…  (speediedan, Feb 21, 2026)
27 changes: 24 additions & 3 deletions circuit_tracer/attribution/attribute.py
@@ -1,7 +1,12 @@
"""
Unified attribution interface that routes to the correct implementation based on the ReplacementModel backend.
Unified attribution interface that routes to the correct backend implementation.

This module provides a unified entry point for computing attribution graphs,
automatically dispatching to either the TransformerLens or NNSight implementation
based on the backend type of the provided ReplacementModel.
"""

from collections.abc import Sequence
from typing import TYPE_CHECKING, Literal

import torch
@@ -19,6 +24,9 @@ def attribute(
prompt: str | torch.Tensor | list[int],
model: "NNSightReplacementModel | TransformerLensReplacementModel",
*,
attribution_targets: (
Sequence[tuple[str, float, torch.Tensor] | int | str] | torch.Tensor | None
Collaborator:

Might be better here to just commit to a format! i.e. input either None, a Sequence of strs, or a sequence of fully specified attribution targets. This cuts down on the number of cases / potential confusion.

Contributor Author:

Makes sense! Supporting the mixed-mode input is probably not worth the marginal complexity and potential user confusion. I've dropped the mixed-mode list[int | str | tuple] entirely.

The revised interface for attribution_targets I'm proposing now is:
Sequence[str] | Sequence[TargetSpec] | torch.Tensor | None

  • Sequence[str] — token strings (AttributionTargets handles tokenization, we warn on multi-token strings and take the final token, lmk if you prefer we error instead or only take the final token if a bool flag is set)
  • Sequence[TargetSpec] — fully specified custom targets with explicit token_str, probability, and unembed vectors. I thought it was worth making the target specification a named tuple for clarity. We still accept regular tuples and create named tuples for the user if preferred but keep the signature readable with this alias:
    TargetSpec = CustomTarget | tuple[str, float, torch.Tensor]
  • torch.Tensor — My reasoning in accepting the torch.Tensor input is that it's convenient for downstream consumers that have pre-tokenized IDs they'd like to attribute without roundtripping through strings. The world model analysis framework I'm building does this, invoking circuit-tracer's attribute via composable operations with pre-tokenized tensor inputs. If you feel strongly about this we can drop it and require users to convert to Sequence[str] or Sequence[TargetSpec], but I think it's a nice-to-have that allows a greater breadth of downstream use cases and doesn't add much complexity.
  • None — auto-select salient logits
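
To make the four accepted modes concrete, here is a minimal sketch of how the dispatch could look. The `CustomTarget` and `TargetSpec` names come from the comment above; the `normalize_targets` helper and its return values are purely illustrative, not the PR's actual implementation.

```python
from typing import NamedTuple, Sequence, Union

import torch


class CustomTarget(NamedTuple):
    """Fully specified attribution target (field names assumed from the discussion)."""
    token_str: str
    probability: float
    unembed_vector: torch.Tensor


# Alias from the comment: plain tuples are accepted and promoted to CustomTarget.
TargetSpec = Union[CustomTarget, tuple[str, float, torch.Tensor]]


def normalize_targets(
    targets: Union[Sequence[str], Sequence[TargetSpec], torch.Tensor, None],
) -> str:
    """Classify which of the four accepted input modes `targets` falls into."""
    if targets is None:
        return "auto-select salient logits"
    if isinstance(targets, torch.Tensor):
        return "pre-tokenized token IDs"
    if all(isinstance(t, str) for t in targets):
        return "token strings (tokenized internally)"
    return "fully specified TargetSpec entries"
```

This keeps the user-facing signature to exactly the four cases listed above, with tuple inputs normalized to the named-tuple form internally.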

Collaborator:

This sounds okay to me - I recognize that having multiple input formats could be convenient (even though I would like to cut down on the number of formats). Re: warning vs. erroring, I'd prefer to error here. If anything, I think I'd want to take the first token, but erroring still seems best.

Contributor Author:

Makes sense! Changed our multi-token string handling from warnings.warn() to raise ValueError().
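
A minimal sketch of the agreed error-on-multi-token behavior; the `resolve_token_id` helper name and the stubbed tokenizer interface are illustrative, not the PR's actual code.

```python
def resolve_token_id(token_str: str, tokenizer) -> int:
    """Resolve a target string to a single token ID, raising ValueError on
    multi-token strings (the behavior agreed on in this thread)."""
    token_ids = tokenizer.encode(token_str, add_special_tokens=False)
    if len(token_ids) != 1:
        raise ValueError(
            f"Attribution target {token_str!r} tokenizes to {len(token_ids)} "
            f"tokens ({token_ids}); targets must be single tokens."
        )
    return token_ids[0]
```
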

) = None,
max_n_logits: int = 10,
desired_logit_prob: float = 0.95,
batch_size: int = 512,
@@ -35,8 +43,19 @@ def attribute(
Args:
prompt: Text, token ids, or tensor - will be tokenized if str.
model: Frozen ``ReplacementModel`` (either nnsight or transformerlens backend)
max_n_logits: Max number of logit nodes.
desired_logit_prob: Keep logits until cumulative prob >= this value.
attribution_targets: Flexible attribution target specification in one of several formats:
- None: Auto-select salient logits based on probability threshold
- torch.Tensor: Tensor of token indices
- Sequence[tuple[str, float, torch.Tensor] | int | str]: Sequence where
each element can be:
* int or str: Token ID/string (auto-resolves probability and
unembed vector)
* tuple[str, float, torch.Tensor]: Fully specified logit spec with
arbitrary string tokens (or functions thereof) that may not be in
vocabulary
max_n_logits: Max number of logit nodes (used when attribution_targets is None).
desired_logit_prob: Keep logits until cumulative prob >= this value
(used when attribution_targets is None).
batch_size: How many source nodes to process per backward pass.
max_feature_nodes: Max number of feature nodes to include in the graph.
offload: Method for offloading model parameters to save memory.
@@ -55,6 +74,7 @@ def attribute(
return attribute_nnsight(
prompt=prompt,
model=model, # type: ignore[arg-type]
attribution_targets=attribution_targets,
max_n_logits=max_n_logits,
desired_logit_prob=desired_logit_prob,
batch_size=batch_size,
@@ -69,6 +89,7 @@ def attribute(
return attribute_transformerlens(
prompt=prompt,
model=model, # type: ignore[arg-type]
attribution_targets=attribution_targets,
max_n_logits=max_n_logits,
desired_logit_prob=desired_logit_prob,
batch_size=batch_size,
59 changes: 40 additions & 19 deletions circuit_tracer/attribution/attribute_nnsight.py
@@ -1,7 +1,7 @@
"""
Build an **attribution graph** that captures the *direct*, *linear* effects
between features and next-token logits for a *prompt-specific*
**local replacement model**.
**local replacement model** using the NNSight backend.

High-level algorithm (matches the 2025 ``Attribution Graphs`` paper):
https://transformer-circuits.pub/2025/attribution-graphs/methods.html
@@ -22,21 +22,25 @@

import logging
import time
from typing import Literal
from collections.abc import Sequence
from typing import Literal, cast

import torch
from tqdm import tqdm

from circuit_tracer.attribution.targets import AttributionTargets
from circuit_tracer.graph import Graph, compute_partial_influences
from circuit_tracer.replacement_model.replacement_model_nnsight import NNSightReplacementModel
from circuit_tracer.utils.disk_offload import offload_modules
from circuit_tracer.utils.salient_logits import compute_salient_logits


def attribute(
prompt: str | torch.Tensor | list[int],
model: NNSightReplacementModel,
*,
attribution_targets: (
Sequence[tuple[str, float, torch.Tensor] | int | str] | torch.Tensor | None
) = None,
max_n_logits: int = 10,
desired_logit_prob: float = 0.95,
batch_size: int = 512,
@@ -45,13 +49,24 @@ def attribute(
verbose: bool = False,
update_interval: int = 4,
) -> Graph:
"""Compute an attribution graph for *prompt*.
"""Compute an attribution graph for *prompt* using NNSight backend.

Args:
prompt: Text, token ids, or tensor - will be tokenized if str.
model: Frozen ``NNSightReplacementModel``
max_n_logits: Max number of logit nodes.
desired_logit_prob: Keep logits until cumulative prob >= this value.
attribution_targets: Flexible attribution target specification in one of several formats:
- None: Auto-select salient logits based on probability threshold
- torch.Tensor: Tensor of token indices
- Sequence[tuple[str, float, torch.Tensor] | int | str]: Sequence where
each element can be:
* int or str: Token ID/string (auto-resolves probability and
unembed vector)
* tuple[str, float, torch.Tensor]: Fully specified logit spec with
arbitrary string tokens (or functions thereof) that may not be in
vocabulary
max_n_logits: Max number of logit nodes (used when attribution_targets is None).
desired_logit_prob: Keep logits until cumulative prob >= this value
(used when attribution_targets is None).
batch_size: How many source nodes to process per backward pass.
max_feature_nodes: Max number of feature nodes to include in the graph.
offload: Method for offloading model parameters to save memory.
@@ -81,6 +96,7 @@ def attribute(
return _run_attribution(
model=model,
prompt=prompt,
attribution_targets=attribution_targets,
max_n_logits=max_n_logits,
desired_logit_prob=desired_logit_prob,
batch_size=batch_size,
@@ -102,6 +118,7 @@
def _run_attribution(
model: NNSightReplacementModel,
prompt,
attribution_targets,
max_n_logits: int,
desired_logit_prob: float,
batch_size: int,
@@ -156,15 +173,21 @@ def _run_attribution(
n_layers, n_pos, _ = activation_matrix.shape
total_active_feats = activation_matrix._nnz()

logit_idx, logit_p, logit_vecs = compute_salient_logits(
ctx.logits[0, -1],
model.unembed_weight, # type: ignore
# Create AttributionTargets using NNSight's unembed_weight accessor
targets = AttributionTargets(
attribution_targets=attribution_targets,
logits=ctx.logits[0, -1],
unembed_proj=cast(torch.Tensor, model.unembed_weight), # NNSight uses unembed_weight
tokenizer=model.tokenizer,
max_n_logits=max_n_logits,
desired_logit_prob=desired_logit_prob,
)
logger.info(
f"Selected {len(logit_idx)} logits with cumulative probability {logit_p.sum().item():.4f}"
)

if attribution_targets is None:
logger.info(
f"Selected {len(targets)} logits with cumulative probability "
f"{targets.logit_probabilities.sum().item():.4f}"
)

if offload:
offload_handles += offload_modules([model.embed_location], offload)
@@ -176,8 +199,7 @@
offload_handles += offload_modules([model.lm_head], offload)

logit_offset = len(feat_layers) + (n_layers + 1) * n_pos
logit_offset = logit_offset
n_logits = len(logit_idx)
n_logits = len(targets)
total_nodes = logit_offset + n_logits

actual_max_feature_nodes = min(max_feature_nodes or total_active_feats, total_active_feats)
@@ -193,8 +215,8 @@
logger.info("Phase 3: Computing logit attributions")
phase3_start = time.time()
i = -1
for i in range(0, len(logit_idx), batch_size):
batch = logit_vecs[i : i + batch_size]
for i in range(0, len(targets), batch_size):
batch = targets.logit_vectors[i : i + batch_size]
rows = ctx.compute_batch(
layers=torch.full((batch.shape[0],), n_layers),
positions=torch.full((batch.shape[0],), n_pos - 1),
@@ -225,7 +247,7 @@ def _run_attribution(
pending = torch.arange(total_active_feats)
else:
influences = compute_partial_influences(
edge_matrix[:st], logit_p, row_to_node_index[:st]
edge_matrix[:st], targets.logit_probabilities, row_to_node_index[:st]
)
feature_rank = torch.argsort(influences[:total_active_feats], descending=True).cpu()
queue_size = min(update_interval * batch_size, actual_max_feature_nodes - n_visited)
@@ -270,8 +292,7 @@ def _run_attribution(
graph = Graph(
input_string=model.tokenizer.decode(input_ids),
input_tokens=input_ids,
logit_tokens=logit_idx,
logit_probabilities=logit_p,
attribution_targets=targets,
active_features=activation_matrix.indices().T,
activation_values=activation_matrix.values(),
selected_features=selected_features,
57 changes: 39 additions & 18 deletions circuit_tracer/attribution/attribute_transformerlens.py
@@ -1,7 +1,7 @@
"""
Build an **attribution graph** that captures the *direct*, *linear* effects
between features and next-token logits for a *prompt-specific*
**local replacement model**.
**local replacement model** using the TransformerLens backend.

High-level algorithm (matches the 2025 ``Attribution Graphs`` paper):
https://transformer-circuits.pub/2025/attribution-graphs/methods.html
@@ -22,23 +22,27 @@

import logging
import time
from collections.abc import Sequence
from typing import Literal

import torch
from tqdm import tqdm

from circuit_tracer.attribution.targets import AttributionTargets
from circuit_tracer.graph import Graph, compute_partial_influences
from circuit_tracer.replacement_model.replacement_model_transformerlens import (
TransformerLensReplacementModel,
)
from circuit_tracer.utils.disk_offload import offload_modules
from circuit_tracer.utils.salient_logits import compute_salient_logits


def attribute(
prompt: str | torch.Tensor | list[int],
model: TransformerLensReplacementModel,
*,
attribution_targets: (
Sequence[tuple[str, float, torch.Tensor] | int | str] | torch.Tensor | None
) = None,
max_n_logits: int = 10,
desired_logit_prob: float = 0.95,
batch_size: int = 512,
@@ -47,13 +51,24 @@ def attribute(
verbose: bool = False,
update_interval: int = 4,
) -> Graph:
"""Compute an attribution graph for *prompt*.
"""Compute an attribution graph for *prompt* using TransformerLens backend.

Args:
prompt: Text, token ids, or tensor - will be tokenized if str.
model: Frozen ``ReplacementModel``
max_n_logits: Max number of logit nodes.
desired_logit_prob: Keep logits until cumulative prob >= this value.
model: Frozen ``TransformerLensReplacementModel``
attribution_targets: Flexible attribution target specification in one of several formats:
- None: Auto-select salient logits based on probability threshold
- torch.Tensor: Tensor of token indices
- Sequence[tuple[str, float, torch.Tensor] | int | str]: Sequence where
each element can be:
* int or str: Token ID/string (auto-resolves probability and
unembed vector)
* tuple[str, float, torch.Tensor]: Fully specified logit spec with
arbitrary string tokens (or functions thereof) that may not be in
vocabulary
max_n_logits: Max number of logit nodes (used when attribution_targets is None).
desired_logit_prob: Keep logits until cumulative prob >= this value
(used when attribution_targets is None).
batch_size: How many source nodes to process per backward pass.
max_feature_nodes: Max number of feature nodes to include in the graph.
offload: Method for offloading model parameters to save memory.
@@ -83,6 +98,7 @@ def attribute(
return _run_attribution(
model=model,
prompt=prompt,
attribution_targets=attribution_targets,
max_n_logits=max_n_logits,
desired_logit_prob=desired_logit_prob,
batch_size=batch_size,
@@ -104,6 +120,7 @@
def _run_attribution(
model,
prompt,
attribution_targets,
max_n_logits,
desired_logit_prob,
batch_size,
@@ -147,21 +164,26 @@ def _run_attribution(
n_layers, n_pos, _ = activation_matrix.shape
total_active_feats = activation_matrix._nnz()

logit_idx, logit_p, logit_vecs = compute_salient_logits(
ctx.logits[0, -1],
model.unembed.W_U,
targets = AttributionTargets(
attribution_targets=attribution_targets,
logits=ctx.logits[0, -1],
unembed_proj=model.unembed.W_U,
tokenizer=model.tokenizer,
max_n_logits=max_n_logits,
desired_logit_prob=desired_logit_prob,
)
logger.info(
f"Selected {len(logit_idx)} logits with cumulative probability {logit_p.sum().item():.4f}"
)

if attribution_targets is None:
logger.info(
f"Selected {len(targets)} logits with cumulative probability "
f"{targets.logit_probabilities.sum().item():.4f}"
)

if offload:
offload_handles += offload_modules([model.unembed, model.embed], offload)

logit_offset = len(feat_layers) + (n_layers + 1) * n_pos
n_logits = len(logit_idx)
n_logits = len(targets)
total_nodes = logit_offset + n_logits

max_feature_nodes = min(max_feature_nodes or total_active_feats, total_active_feats)
@@ -176,8 +198,8 @@
# Phase 3: logit attribution
logger.info("Phase 3: Computing logit attributions")
phase_start = time.time()
for i in range(0, len(logit_idx), batch_size):
batch = logit_vecs[i : i + batch_size]
for i in range(0, len(targets), batch_size):
batch = targets.logit_vectors[i : i + batch_size]
rows = ctx.compute_batch(
layers=torch.full((batch.shape[0],), n_layers),
positions=torch.full((batch.shape[0],), n_pos - 1),
@@ -203,7 +225,7 @@
pending = torch.arange(total_active_feats)
else:
influences = compute_partial_influences(
edge_matrix[:st], logit_p, row_to_node_index[:st]
edge_matrix[:st], targets.logit_probabilities, row_to_node_index[:st]
)
feature_rank = torch.argsort(influences[:total_active_feats], descending=True).cpu()
queue_size = min(update_interval * batch_size, max_feature_nodes - n_visited)
@@ -248,8 +270,7 @@ def _run_attribution(
graph = Graph(
input_string=model.tokenizer.decode(input_ids),
input_tokens=input_ids,
logit_tokens=logit_idx,
logit_probabilities=logit_p,
attribution_targets=targets,
active_features=activation_matrix.indices().T,
activation_values=activation_matrix.values(),
selected_features=selected_features,