Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
2149aa0
Release: 2025-10-21
hannamw Oct 21, 2025
b6bf15f
A proposed API enhancement (`AttributionTargets`) that encapsulates t…
speediedan Nov 6, 2025
2b52ba6
slight clarification in a couple comments based on copilot review
speediedan Nov 7, 2025
418b17f
Allow `offload_modules` to handle single module and container offload…
speediedan Dec 16, 2025
45d6eee
use a smaller non-default batch_size for test_gemma_2_2b to expand th…
speediedan Dec 16, 2025
3d8264a
initial changes to adapt original `AttributionTargets` encapsulation …
speediedan Jan 13, 2026
32f79a4
Merge remote-tracking branch 'upstream/main' into attribution-targets
speediedan Jan 13, 2026
f90a7a4
Add pytest markers for long-running and high memory tests; adjust bat…
speediedan Jan 15, 2026
af101cf
Update unembedding matrix handling to auto-detect backend-variant ori…
speediedan Jan 15, 2026
6f93ef4
Merge branch 'main' into attribution-targets
speediedan Jan 16, 2026
2f8eb2b
minor type fix, clarify current vram gating mark
speediedan Jan 17, 2026
6b76d08
adds integration tests, refactors proposed interface incorporating PR…
speediedan Feb 10, 2026
c8b4425
remove marks, add vram skipif conditions
speediedan Feb 11, 2026
f4ad1e9
revert comment hunks to submit in a separate PR
speediedan Feb 11, 2026
787cf18
adjust `test_custom_target_correctness` to adopt our standard attribu…
speediedan Feb 12, 2026
7ea7f14
very rough exploratory draft of attribution_targets_demo.ipynb to sol…
speediedan Feb 18, 2026
be7654d
updates to attribution_targets_demo.ipynb including:
speediedan Feb 19, 2026
440f097
cleanup language and formatting, prettify with banner
speediedan Feb 20, 2026
8219073
streamlined serial backend testing with `models_cpu` fixture and `cle…
speediedan Feb 21, 2026
100af4d
restructured the demo to lead with the simpler target modes and extra…
speediedan Feb 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions circuit_tracer/attribution/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
Unified attribution interface that routes to the correct implementation based on the ReplacementModel backend.
"""

from collections.abc import Sequence
from typing import TYPE_CHECKING, Literal

import torch

from circuit_tracer.graph import Graph

if TYPE_CHECKING:
from circuit_tracer.attribution.targets import TargetSpec
from circuit_tracer.replacement_model.replacement_model_nnsight import NNSightReplacementModel
from circuit_tracer.replacement_model.replacement_model_transformerlens import (
TransformerLensReplacementModel,
Expand All @@ -19,6 +21,7 @@ def attribute(
prompt: str | torch.Tensor | list[int],
model: "NNSightReplacementModel | TransformerLensReplacementModel",
*,
attribution_targets: "Sequence[str] | Sequence[TargetSpec] | torch.Tensor | None" = None,
max_n_logits: int = 10,
desired_logit_prob: float = 0.95,
batch_size: int = 512,
Expand All @@ -35,8 +38,16 @@ def attribute(
Args:
prompt: Text, token ids, or tensor - will be tokenized if str.
model: Frozen ``ReplacementModel`` (either nnsight or transformerlens backend)
max_n_logits: Max number of logit nodes.
desired_logit_prob: Keep logits until cumulative prob >= this value.
attribution_targets: Target specification in one of four formats:
- None: Auto-select salient logits based on probability threshold
- torch.Tensor: Tensor of token indices
- Sequence[str]: Token strings (tokenized, auto-computes probability
and unembed vector)
- Sequence[TargetSpec]: Fully specified custom targets (CustomTarget or
tuple[str, float, torch.Tensor]) with arbitrary unembed directions
max_n_logits: Max number of logit nodes (used when attribution_targets is None).
desired_logit_prob: Keep logits until cumulative prob >= this value
(used when attribution_targets is None).
batch_size: How many source nodes to process per backward pass.
max_feature_nodes: Max number of feature nodes to include in the graph.
offload: Method for offloading model parameters to save memory.
Expand All @@ -55,6 +66,7 @@ def attribute(
return attribute_nnsight(
prompt=prompt,
model=model, # type: ignore[arg-type]
attribution_targets=attribution_targets,
max_n_logits=max_n_logits,
desired_logit_prob=desired_logit_prob,
batch_size=batch_size,
Expand All @@ -69,6 +81,7 @@ def attribute(
return attribute_transformerlens(
prompt=prompt,
model=model, # type: ignore[arg-type]
attribution_targets=attribution_targets,
max_n_logits=max_n_logits,
desired_logit_prob=desired_logit_prob,
batch_size=batch_size,
Expand Down
52 changes: 34 additions & 18 deletions circuit_tracer/attribution/attribute_nnsight.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,27 @@

import logging
import time
from typing import Literal
from collections.abc import Sequence
from typing import Literal, cast

import torch
from tqdm import tqdm

from circuit_tracer.attribution.targets import (
AttributionTargets,
TargetSpec,
log_attribution_target_info,
)
from circuit_tracer.graph import Graph, compute_partial_influences
from circuit_tracer.replacement_model.replacement_model_nnsight import NNSightReplacementModel
from circuit_tracer.utils.disk_offload import offload_modules
from circuit_tracer.utils.salient_logits import compute_salient_logits


def attribute(
prompt: str | torch.Tensor | list[int],
model: NNSightReplacementModel,
*,
attribution_targets: Sequence[str] | Sequence[TargetSpec] | torch.Tensor | None = None,
max_n_logits: int = 10,
desired_logit_prob: float = 0.95,
batch_size: int = 512,
Expand All @@ -45,13 +51,21 @@ def attribute(
verbose: bool = False,
update_interval: int = 4,
) -> Graph:
"""Compute an attribution graph for *prompt*.
"""Compute an attribution graph for *prompt* using NNSight backend.

Args:
prompt: Text, token ids, or tensor - will be tokenized if str.
model: Frozen ``NNSightReplacementModel``
max_n_logits: Max number of logit nodes.
desired_logit_prob: Keep logits until cumulative prob >= this value.
attribution_targets: Target specification in one of four formats:
- None: Auto-select salient logits based on probability threshold
- torch.Tensor: Tensor of token indices
- Sequence[str]: Token strings (tokenized, auto-computes probability
and unembed vector)
- Sequence[TargetSpec]: Fully specified custom targets (CustomTarget or
tuple[str, float, torch.Tensor]) with arbitrary unembed directions
max_n_logits: Max number of logit nodes (used when attribution_targets is None).
desired_logit_prob: Keep logits until cumulative prob >= this value
(used when attribution_targets is None).
batch_size: How many source nodes to process per backward pass.
max_feature_nodes: Max number of feature nodes to include in the graph.
offload: Method for offloading model parameters to save memory.
Expand Down Expand Up @@ -81,6 +95,7 @@ def attribute(
return _run_attribution(
model=model,
prompt=prompt,
attribution_targets=attribution_targets,
max_n_logits=max_n_logits,
desired_logit_prob=desired_logit_prob,
batch_size=batch_size,
Expand All @@ -102,6 +117,7 @@ def attribute(
def _run_attribution(
model: NNSightReplacementModel,
prompt,
attribution_targets,
max_n_logits: int,
desired_logit_prob: float,
batch_size: int,
Expand Down Expand Up @@ -156,15 +172,17 @@ def _run_attribution(
n_layers, n_pos, _ = activation_matrix.shape
total_active_feats = activation_matrix._nnz()

logit_idx, logit_p, logit_vecs = compute_salient_logits(
ctx.logits[0, -1],
model.unembed_weight, # type: ignore
# Create AttributionTargets using NNSight's unembed_weight accessor
targets = AttributionTargets(
attribution_targets=attribution_targets,
logits=ctx.logits[0, -1],
unembed_proj=cast(torch.Tensor, model.unembed_weight), # NNSight uses unembed_weight
tokenizer=model.tokenizer,
max_n_logits=max_n_logits,
desired_logit_prob=desired_logit_prob,
)
logger.info(
f"Selected {len(logit_idx)} logits with cumulative probability {logit_p.sum().item():.4f}"
)

log_attribution_target_info(targets, attribution_targets, logger)

if offload:
offload_handles += offload_modules([model.embed_location], offload)
Expand All @@ -176,8 +194,7 @@ def _run_attribution(
offload_handles += offload_modules([model.lm_head], offload)

logit_offset = len(feat_layers) + (n_layers + 1) * n_pos
logit_offset = logit_offset
n_logits = len(logit_idx)
n_logits = len(targets)
total_nodes = logit_offset + n_logits

actual_max_feature_nodes = min(max_feature_nodes or total_active_feats, total_active_feats)
Expand All @@ -193,8 +210,8 @@ def _run_attribution(
logger.info("Phase 3: Computing logit attributions")
phase3_start = time.time()
i = -1
for i in range(0, len(logit_idx), batch_size):
batch = logit_vecs[i : i + batch_size]
for i in range(0, len(targets), batch_size):
batch = targets.logit_vectors[i : i + batch_size]
rows = ctx.compute_batch(
layers=torch.full((batch.shape[0],), n_layers),
positions=torch.full((batch.shape[0],), n_pos - 1),
Expand Down Expand Up @@ -225,7 +242,7 @@ def _run_attribution(
pending = torch.arange(total_active_feats)
else:
influences = compute_partial_influences(
edge_matrix[:st], logit_p, row_to_node_index[:st]
edge_matrix[:st], targets.logit_probabilities, row_to_node_index[:st]
)
feature_rank = torch.argsort(influences[:total_active_feats], descending=True).cpu()
queue_size = min(update_interval * batch_size, actual_max_feature_nodes - n_visited)
Expand Down Expand Up @@ -270,8 +287,7 @@ def _run_attribution(
graph = Graph(
input_string=model.tokenizer.decode(input_ids),
input_tokens=input_ids,
logit_tokens=logit_idx,
logit_probabilities=logit_p,
attribution_targets=targets,
active_features=activation_matrix.indices().T,
activation_values=activation_matrix.values(),
selected_features=selected_features,
Expand Down
50 changes: 33 additions & 17 deletions circuit_tracer/attribution/attribute_transformerlens.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,29 @@

import logging
import time
from collections.abc import Sequence
from typing import Literal

import torch
from tqdm import tqdm

from circuit_tracer.attribution.targets import (
AttributionTargets,
TargetSpec,
log_attribution_target_info,
)
from circuit_tracer.graph import Graph, compute_partial_influences
from circuit_tracer.replacement_model.replacement_model_transformerlens import (
TransformerLensReplacementModel,
)
from circuit_tracer.utils.disk_offload import offload_modules
from circuit_tracer.utils.salient_logits import compute_salient_logits


def attribute(
prompt: str | torch.Tensor | list[int],
model: TransformerLensReplacementModel,
*,
attribution_targets: Sequence[str] | Sequence[TargetSpec] | torch.Tensor | None = None,
max_n_logits: int = 10,
desired_logit_prob: float = 0.95,
batch_size: int = 512,
Expand All @@ -47,13 +53,21 @@ def attribute(
verbose: bool = False,
update_interval: int = 4,
) -> Graph:
"""Compute an attribution graph for *prompt*.
"""Compute an attribution graph for *prompt* using TransformerLens backend.

Args:
prompt: Text, token ids, or tensor - will be tokenized if str.
model: Frozen ``ReplacementModel``
max_n_logits: Max number of logit nodes.
desired_logit_prob: Keep logits until cumulative prob >= this value.
model: Frozen ``TransformerLensReplacementModel``
attribution_targets: Target specification in one of four formats:
- None: Auto-select salient logits based on probability threshold
- torch.Tensor: Tensor of token indices
- Sequence[str]: Token strings (tokenized, auto-computes probability
and unembed vector)
- Sequence[TargetSpec]: Fully specified custom targets (CustomTarget or
tuple[str, float, torch.Tensor]) with arbitrary unembed directions
max_n_logits: Max number of logit nodes (used when attribution_targets is None).
desired_logit_prob: Keep logits until cumulative prob >= this value
(used when attribution_targets is None).
batch_size: How many source nodes to process per backward pass.
max_feature_nodes: Max number of feature nodes to include in the graph.
offload: Method for offloading model parameters to save memory.
Expand Down Expand Up @@ -83,6 +97,7 @@ def attribute(
return _run_attribution(
model=model,
prompt=prompt,
attribution_targets=attribution_targets,
max_n_logits=max_n_logits,
desired_logit_prob=desired_logit_prob,
batch_size=batch_size,
Expand All @@ -104,6 +119,7 @@ def attribute(
def _run_attribution(
model,
prompt,
attribution_targets,
max_n_logits,
desired_logit_prob,
batch_size,
Expand Down Expand Up @@ -147,21 +163,22 @@ def _run_attribution(
n_layers, n_pos, _ = activation_matrix.shape
total_active_feats = activation_matrix._nnz()

logit_idx, logit_p, logit_vecs = compute_salient_logits(
ctx.logits[0, -1],
model.unembed.W_U,
targets = AttributionTargets(
attribution_targets=attribution_targets,
logits=ctx.logits[0, -1],
unembed_proj=model.unembed.W_U,
tokenizer=model.tokenizer,
max_n_logits=max_n_logits,
desired_logit_prob=desired_logit_prob,
)
logger.info(
f"Selected {len(logit_idx)} logits with cumulative probability {logit_p.sum().item():.4f}"
)

log_attribution_target_info(targets, attribution_targets, logger)

if offload:
offload_handles += offload_modules([model.unembed, model.embed], offload)

logit_offset = len(feat_layers) + (n_layers + 1) * n_pos
n_logits = len(logit_idx)
n_logits = len(targets)
total_nodes = logit_offset + n_logits

max_feature_nodes = min(max_feature_nodes or total_active_feats, total_active_feats)
Expand All @@ -176,8 +193,8 @@ def _run_attribution(
# Phase 3: logit attribution
logger.info("Phase 3: Computing logit attributions")
phase_start = time.time()
for i in range(0, len(logit_idx), batch_size):
batch = logit_vecs[i : i + batch_size]
for i in range(0, len(targets), batch_size):
batch = targets.logit_vectors[i : i + batch_size]
rows = ctx.compute_batch(
layers=torch.full((batch.shape[0],), n_layers),
positions=torch.full((batch.shape[0],), n_pos - 1),
Expand All @@ -203,7 +220,7 @@ def _run_attribution(
pending = torch.arange(total_active_feats)
else:
influences = compute_partial_influences(
edge_matrix[:st], logit_p, row_to_node_index[:st]
edge_matrix[:st], targets.logit_probabilities, row_to_node_index[:st]
)
feature_rank = torch.argsort(influences[:total_active_feats], descending=True).cpu()
queue_size = min(update_interval * batch_size, max_feature_nodes - n_visited)
Expand Down Expand Up @@ -248,8 +265,7 @@ def _run_attribution(
graph = Graph(
input_string=model.tokenizer.decode(input_ids),
input_tokens=input_ids,
logit_tokens=logit_idx,
logit_probabilities=logit_p,
attribution_targets=targets,
active_features=activation_matrix.indices().T,
activation_values=activation_matrix.values(),
selected_features=selected_features,
Expand Down
Loading