Skip to content

Commit 209ed13

Browse files
authored
Use new proposed AttributionTargets API for circuit-tracer, misc infra improvements (#187)
* Switch a basic circuit-tracer (ct) example to use the CrossLayerTranscoder (clt) by default, improve infra utils, and make a minor adjustment to expected test results for PyTorch (pt) 2.9.
* Use the new proposed AttributionTargets API for circuit-tracer; misc infra improvements.
1 parent 016b24c commit 209ed13

File tree

17 files changed

+118
-73
lines changed

17 files changed

+118
-73
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ examples = [
7878
"neuronpedia",
7979
# TODO: add our packaged circuit-tracer dep (either pypi or pypi fork) once it is available and remove the
8080
# `install_circuit_tracer` tool
81-
# "circuit-tracer @ git+https://github.com/speediedan/circuit-tracer.git@b228bf190fadb3cb30f6a5ba6691dc4c86d76ba3",
81+
# "circuit-tracer @ git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87",
8282
]
8383

8484
docs = [
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
b228bf190fadb3cb30f6a5ba6691dc4c86d76ba3
1+
004f1b2822eca3f0c1ddd2389e9105b3abffde87

requirements/ci/requirements.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,4 @@ pip < 25.3
3333
huggingface_hub[hf_xet]
3434
nbmake >= 1.5.0
3535
papermill >= 2.4.0
36-
git+https://github.com/speediedan/circuit-tracer.git@b228bf190fadb3cb30f6a5ba6691dc4c86d76ba3#egg=circuit-tracer
36+
git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87#egg=circuit-tracer

requirements/ci/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#
55
# pip-compile --no-strip-extras --output-file=requirements/ci/requirements.txt requirements/ci/requirements.in
66
#
7-
circuit-tracer @ git+https://github.com/speediedan/circuit-tracer.git@b228bf190fadb3cb30f6a5ba6691dc4c86d76ba3
7+
circuit-tracer @ git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87
88
# via -r requirements/ci/requirements.in
99
coverage==7.11.0
1010
# via

scripts/build_it_env.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ it_install(){
195195

196196
# Verify only the editable source installation is installed
197197
echo "Verifying circuit-tracer installation..."
198-
if pip show circuit_tracer | grep -q "Editable project location:"; then
198+
if pip show circuit_tracer 2>/dev/null | grep -q "Editable project location:"; then
199199
echo "✓ circuit_tracer is installed in editable mode"
200200
else
201201
echo "✗ circuit_tracer is not installed in editable mode"

scripts/infra_utils.sh

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,5 @@ show_elapsed_time(){
1919

2020
# Function to safely deactivate a virtual environment if one is active
2121
maybe_deactivate(){
22-
if [ -n "$VIRTUAL_ENV" ]; then
23-
deactivate
24-
fi
22+
deactivate 2>/dev/null || true
2523
}

src/interpretune/adapters/circuit_tracer.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -115,30 +115,24 @@ def set_input_require_grads(self) -> None:
115115
# Circuit tracer handles gradient requirements internally
116116
rank_zero_info("Input gradient requirements handled by circuit tracer internally.")
117117

118-
def _get_analysis_target_indices(self) -> Optional[torch.Tensor]:
119-
"""Determine the value for compute_specific_logits based on CircuitTracerConfig.
118+
def _get_attribution_targets(self) -> Optional[list | torch.Tensor]:
119+
"""Determine the attribution_targets value based on CircuitTracerConfig.
120120
121-
Returns a 1D tensor of token ids, or None.
121+
Returns:
122+
- None: Auto-select salient logits (default behavior)
123+
- list[str]: Token strings to analyze (will be converted by AttributionTargets)
124+
- torch.Tensor: Tensor of token IDs
122125
"""
123126
cfg = self.circuit_tracer_cfg
124127
if not cfg:
125128
return None
126129

127-
# If analysis_target_tokens is set, tokenize them
130+
# If analysis_target_tokens is set, return as list of strings
131+
# AttributionTargets will handle tokenization internally
128132
if cfg.analysis_target_tokens is not None:
129-
tokenizer = self.datamodule.tokenizer if self.datamodule else self.it_cfg.tokenizer
130-
# Tokenize and flatten to 1D tensor of token ids
131-
token_ids = []
132-
for token in cfg.analysis_target_tokens:
133-
assert tokenizer is not None, "Tokenizer must be available to tokenize analysis_target_tokens"
134-
ids = tokenizer.encode(token, add_special_tokens=False)
135-
token_ids.extend(ids)
136-
if token_ids:
137-
return torch.tensor(token_ids, dtype=torch.long)
138-
else:
139-
return None
133+
return cfg.analysis_target_tokens
140134

141-
# If target_token_ids is set
135+
# If target_token_ids is set, process it
142136
if cfg.target_token_ids is not None:
143137
ids = cfg.target_token_ids
144138
if isinstance(ids, torch.Tensor):
@@ -155,7 +149,7 @@ def _get_analysis_target_indices(self) -> Optional[torch.Tensor]:
155149
else:
156150
return None
157151

158-
# If neither is set, return None
152+
# If neither is set, return None (use salient logits)
159153
return None
160154

161155
def generate_attribution_graph(self, prompt: str, **kwargs) -> Graph:
@@ -165,18 +159,18 @@ def generate_attribution_graph(self, prompt: str, **kwargs) -> Graph:
165159

166160
cfg = self.circuit_tracer_cfg
167161

168-
# Determine compute_specific_logits using the new method
169-
analysis_target_indices = self._get_analysis_target_indices()
162+
# Determine attribution_targets using the new method
163+
attribution_targets = self._get_attribution_targets()
170164

171165
# Set default attribution parameters
172166
attribution_kwargs = {
167+
"attribution_targets": attribution_targets,
173168
"max_n_logits": cfg.max_n_logits if cfg else 10,
174169
"desired_logit_prob": cfg.desired_logit_prob if cfg else 0.95,
175170
"batch_size": cfg.batch_size if cfg else 256,
176171
"max_feature_nodes": cfg.max_feature_nodes if cfg else None,
177172
"offload": cfg.offload if cfg else None,
178173
"verbose": cfg.verbose if cfg else True,
179-
"analysis_target_indices": analysis_target_indices,
180174
}
181175

182176
# Override with any provided kwargs

src/it_examples/notebooks/dev/attribution_analysis/analysis_points.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -68,48 +68,49 @@ def ap_build_input_vectors_end(local_vars: Dict[str, Any]) -> None:
6868
v = get_analysis_vars(
6969
context_keys=["target_token_analysis"],
7070
local_keys=[
71-
"logit_idx",
72-
"logit_p",
71+
"targets",
7372
"total_nodes",
7473
"total_active_feats",
7574
"max_feature_nodes",
7675
"edge_matrix",
7776
"row_to_node_index",
78-
"logit_vecs",
7977
"n_layers",
8078
"n_pos",
8179
],
8280
local_vars=local_vars,
8381
)
8482
tta = v["target_token_analysis"]
85-
tta.update_logit_info(v["logit_idx"], v["logit_p"])
86-
max_n_logits = len(v["logit_idx"])
83+
targets = v["targets"]
84+
# Extract token IDs and probabilities from targets object
85+
logit_idx, logit_p = targets.token_ids, targets.logit_probabilities
86+
tta.update_logit_info(logit_idx, logit_p)
87+
max_n_logits = len(targets)
8788

8889
HOOK_REGISTRY.set_context(
8990
max_feature_nodes=v["max_feature_nodes"],
9091
total_nodes=v["total_nodes"],
91-
logit_idx=v["logit_idx"],
92-
logit_p=v["logit_p"],
92+
logit_idx=logit_idx,
93+
logit_p=logit_p,
9394
n_pos=v["n_pos"],
9495
max_n_logits=max_n_logits,
9596
)
9697
data = {
97-
"logit_idx": v["logit_idx"],
98-
"logit_p": VarAnnotate("logit_p", var_value=v["logit_p"], annotation="non-demeaned logits probabilities"),
98+
"logit_idx": logit_idx,
99+
"logit_p": VarAnnotate("logit_p", var_value=logit_p, annotation="logit probabilities"),
99100
"target_tokens": tta.tokens,
100101
"target_logit_indices": tta.logit_indices,
101102
"target_logit_p": tta.logit_probabilities,
102-
"logit_cumulative_prob": float(v["logit_p"].sum().item()),
103+
"logit_cumulative_prob": float(logit_p.sum().item()),
103104
"total_nodes": v["total_nodes"],
104105
"max_feature_nodes": v["max_feature_nodes"],
105106
"total_active_feats": v["total_active_feats"],
106-
"n_logits": len(v["logit_idx"]),
107+
"n_logits": len(targets),
107108
"n_layers": v["n_layers"],
108109
"n_pos": v["n_pos"],
109110
"max_n_logits": max_n_logits,
110111
"edge_matrix.shape": v["edge_matrix"].shape,
111112
"row_to_node_index.shape": v["row_to_node_index"].shape if v["row_to_node_index"] is not None else None,
112-
"logit_vecs.shape": v["logit_vecs"].shape,
113+
"logit_vecs.shape": targets.logit_vectors.shape,
113114
}
114115
analysis_log_point("after building input vectors w/ target logits", data)
115116

@@ -151,7 +152,7 @@ def ap_compute_feature_attributions_end(local_vars: Dict[str, Any]) -> None:
151152
# Use dict directly for cleaner access
152153
v = get_analysis_vars(
153154
context_keys=["target_token_analysis"],
154-
local_keys=["n_visited", "max_feature_nodes", "logit_p", "edge_matrix", "ctx"],
155+
local_keys=["n_visited", "max_feature_nodes", "targets", "edge_matrix", "ctx"],
155156
local_vars=local_vars,
156157
)
157158
tta = v["target_token_analysis"]

src/it_examples/notebooks/dev/attribution_analysis/attribution_analysis.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
"if should_install_package(\"circuit-tracer\"):\n",
6262
" # Note: Using line continuation for long git URL\n",
6363
" !python -m pip install \\\n",
64-
" 'git+https://github.com/speediedan/circuit-tracer.git@b228bf190fadb3cb30f6a5ba6691dc4c86d76ba3'"
64+
" 'git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87'"
6565
]
6666
},
6767
{

src/it_examples/notebooks/dev/circuit_tracer_examples/circuit_tracer_adapter_example_basic_clt.ipynb

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@
107107
"source": [
108108
"# Parameters - These will be injected by papermill during parameterized test runs\n",
109109
"use_baseline_salient_logits = True # logits computation mode: True->salient logits, False->specific logits\n",
110-
"use_baseline_transcoder_arch = True # transcoder architecture: True->SingleLayerTranscoder, False->CrossLayerTranscoder\n",
110+
"use_baseline_transcoder_arch = (\n",
111+
" False # transcoder architecture: True->SingleLayerTranscoder, False->CrossLayerTranscoder\n",
112+
")\n",
111113
"core_log_dir = None # Directory to save analysis logs (if None, a temp directory will be created)"
112114
]
113115
},
@@ -347,11 +349,13 @@
347349
"if port_forwarding:\n",
348350
" hostname = \"localhost\" # use localhost for port forwarding\n",
349351
" print(\n",
350-
" f\"Using port forwarding (ensure it is configured) and localhost. Open your graph here at http://{hostname}:{port}/index.html\"\n",
352+
" f\"Using port forwarding (ensure it is configured) and localhost.\"\n",
353+
" f\" Open your graph here at http://{hostname}:{port}/index.html\"\n",
351354
" )\n",
352355
"else:\n",
353356
" print(\n",
354-
" f\"Not using port forwarding. Use the IFrame below, or open your graph here directly at http://{hostname}:{port}/index.html\"\n",
357+
" f\"Not using port forwarding. Use the IFrame below, or\"\n",
358+
" f\" open your graph here directly at http://{hostname}:{port}/index.html\"\n",
355359
" )\n",
356360
"\n",
357361
"if enable_iframe:\n",

0 commit comments

Comments
 (0)