Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ examples = [
"neuronpedia",
# TODO: add our packaged circuit-tracer dep (either pypi or pypi fork) once it is available and remove the
# `install_circuit_tracer` tool
# "circuit-tracer @ git+https://github.com/speediedan/circuit-tracer.git@b228bf190fadb3cb30f6a5ba6691dc4c86d76ba3",
# "circuit-tracer @ git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87",
]

docs = [
Expand Down
2 changes: 1 addition & 1 deletion requirements/ci/circuit_tracer_pin.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
b228bf190fadb3cb30f6a5ba6691dc4c86d76ba3
004f1b2822eca3f0c1ddd2389e9105b3abffde87
2 changes: 1 addition & 1 deletion requirements/ci/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ pip < 25.3
huggingface_hub[hf_xet]
nbmake >= 1.5.0
papermill >= 2.4.0
git+https://github.com/speediedan/circuit-tracer.git@b228bf190fadb3cb30f6a5ba6691dc4c86d76ba3#egg=circuit-tracer
git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87#egg=circuit-tracer
2 changes: 1 addition & 1 deletion requirements/ci/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#
# pip-compile --no-strip-extras --output-file=requirements/ci/requirements.txt requirements/ci/requirements.in
#
circuit-tracer @ git+https://github.com/speediedan/circuit-tracer.git@b228bf190fadb3cb30f6a5ba6691dc4c86d76ba3
circuit-tracer @ git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87
# via -r requirements/ci/requirements.in
coverage==7.11.0
# via
Expand Down
2 changes: 1 addition & 1 deletion scripts/build_it_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ it_install(){

# Verify only the editable source installation is installed
echo "Verifying circuit-tracer installation..."
if pip show circuit_tracer | grep -q "Editable project location:"; then
if pip show circuit_tracer 2>/dev/null | grep -q "Editable project location:"; then
echo "✓ circuit_tracer is installed in editable mode"
else
echo "✗ circuit_tracer is not installed in editable mode"
Expand Down
4 changes: 1 addition & 3 deletions scripts/infra_utils.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,5 @@ show_elapsed_time(){

# Function to safely deactivate a virtual environment if one is active
maybe_deactivate(){
if [ -n "$VIRTUAL_ENV" ]; then
deactivate
fi
deactivate 2>/dev/null || true
}
34 changes: 14 additions & 20 deletions src/interpretune/adapters/circuit_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,30 +115,24 @@ def set_input_require_grads(self) -> None:
# Circuit tracer handles gradient requirements internally
rank_zero_info("Input gradient requirements handled by circuit tracer internally.")

def _get_analysis_target_indices(self) -> Optional[torch.Tensor]:
"""Determine the value for compute_specific_logits based on CircuitTracerConfig.
def _get_attribution_targets(self) -> Optional[list | torch.Tensor]:
"""Determine the attribution_targets value based on CircuitTracerConfig.

Returns a 1D tensor of token ids, or None.
Returns:
- None: Auto-select salient logits (default behavior)
- list[str]: Token strings to analyze (will be converted by AttributionTargets)
- torch.Tensor: Tensor of token IDs
"""
cfg = self.circuit_tracer_cfg
if not cfg:
return None

# If analysis_target_tokens is set, tokenize them
# If analysis_target_tokens is set, return as list of strings
# AttributionTargets will handle tokenization internally
if cfg.analysis_target_tokens is not None:
tokenizer = self.datamodule.tokenizer if self.datamodule else self.it_cfg.tokenizer
# Tokenize and flatten to 1D tensor of token ids
token_ids = []
for token in cfg.analysis_target_tokens:
assert tokenizer is not None, "Tokenizer must be available to tokenize analysis_target_tokens"
ids = tokenizer.encode(token, add_special_tokens=False)
token_ids.extend(ids)
if token_ids:
return torch.tensor(token_ids, dtype=torch.long)
else:
return None
return cfg.analysis_target_tokens

# If target_token_ids is set
# If target_token_ids is set, process it
if cfg.target_token_ids is not None:
ids = cfg.target_token_ids
if isinstance(ids, torch.Tensor):
Expand All @@ -155,7 +149,7 @@ def _get_analysis_target_indices(self) -> Optional[torch.Tensor]:
else:
return None

# If neither is set, return None
# If neither is set, return None (use salient logits)
return None

def generate_attribution_graph(self, prompt: str, **kwargs) -> Graph:
Expand All @@ -165,18 +159,18 @@ def generate_attribution_graph(self, prompt: str, **kwargs) -> Graph:

cfg = self.circuit_tracer_cfg

# Determine compute_specific_logits using the new method
analysis_target_indices = self._get_analysis_target_indices()
# Determine attribution_targets using the new method
attribution_targets = self._get_attribution_targets()

# Set default attribution parameters
attribution_kwargs = {
"attribution_targets": attribution_targets,
"max_n_logits": cfg.max_n_logits if cfg else 10,
"desired_logit_prob": cfg.desired_logit_prob if cfg else 0.95,
"batch_size": cfg.batch_size if cfg else 256,
"max_feature_nodes": cfg.max_feature_nodes if cfg else None,
"offload": cfg.offload if cfg else None,
"verbose": cfg.verbose if cfg else True,
"analysis_target_indices": analysis_target_indices,
}

# Override with any provided kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,48 +68,49 @@ def ap_build_input_vectors_end(local_vars: Dict[str, Any]) -> None:
v = get_analysis_vars(
context_keys=["target_token_analysis"],
local_keys=[
"logit_idx",
"logit_p",
"targets",
"total_nodes",
"total_active_feats",
"max_feature_nodes",
"edge_matrix",
"row_to_node_index",
"logit_vecs",
"n_layers",
"n_pos",
],
local_vars=local_vars,
)
tta = v["target_token_analysis"]
tta.update_logit_info(v["logit_idx"], v["logit_p"])
max_n_logits = len(v["logit_idx"])
targets = v["targets"]
# Extract token IDs and probabilities from targets object
logit_idx, logit_p = targets.token_ids, targets.logit_probabilities
tta.update_logit_info(logit_idx, logit_p)
max_n_logits = len(targets)

HOOK_REGISTRY.set_context(
max_feature_nodes=v["max_feature_nodes"],
total_nodes=v["total_nodes"],
logit_idx=v["logit_idx"],
logit_p=v["logit_p"],
logit_idx=logit_idx,
logit_p=logit_p,
n_pos=v["n_pos"],
max_n_logits=max_n_logits,
)
data = {
"logit_idx": v["logit_idx"],
"logit_p": VarAnnotate("logit_p", var_value=v["logit_p"], annotation="non-demeaned logits probabilities"),
"logit_idx": logit_idx,
"logit_p": VarAnnotate("logit_p", var_value=logit_p, annotation="logit probabilities"),
"target_tokens": tta.tokens,
"target_logit_indices": tta.logit_indices,
"target_logit_p": tta.logit_probabilities,
"logit_cumulative_prob": float(v["logit_p"].sum().item()),
"logit_cumulative_prob": float(logit_p.sum().item()),
"total_nodes": v["total_nodes"],
"max_feature_nodes": v["max_feature_nodes"],
"total_active_feats": v["total_active_feats"],
"n_logits": len(v["logit_idx"]),
"n_logits": len(targets),
"n_layers": v["n_layers"],
"n_pos": v["n_pos"],
"max_n_logits": max_n_logits,
"edge_matrix.shape": v["edge_matrix"].shape,
"row_to_node_index.shape": v["row_to_node_index"].shape if v["row_to_node_index"] is not None else None,
"logit_vecs.shape": v["logit_vecs"].shape,
"logit_vecs.shape": targets.logit_vectors.shape,
}
analysis_log_point("after building input vectors w/ target logits", data)

Expand Down Expand Up @@ -151,7 +152,7 @@ def ap_compute_feature_attributions_end(local_vars: Dict[str, Any]) -> None:
# Use dict directly for cleaner access
v = get_analysis_vars(
context_keys=["target_token_analysis"],
local_keys=["n_visited", "max_feature_nodes", "logit_p", "edge_matrix", "ctx"],
local_keys=["n_visited", "max_feature_nodes", "targets", "edge_matrix", "ctx"],
local_vars=local_vars,
)
tta = v["target_token_analysis"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
"if should_install_package(\"circuit-tracer\"):\n",
" # Note: Using line continuation for long git URL\n",
" !python -m pip install \\\n",
" 'git+https://github.com/speediedan/circuit-tracer.git@b228bf190fadb3cb30f6a5ba6691dc4c86d76ba3'"
" 'git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87'"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@
"source": [
"# Parameters - These will be injected by papermill during parameterized test runs\n",
"use_baseline_salient_logits = True # logits computation mode: True->salient logits, False->specific logits\n",
"use_baseline_transcoder_arch = True # transcoder architecture: True->SingleLayerTranscoder, False->CrossLayerTranscoder\n",
"use_baseline_transcoder_arch = (\n",
" False # transcoder architecture: True->SingleLayerTranscoder, False->CrossLayerTranscoder\n",
")\n",
"core_log_dir = None # Directory to save analysis logs (if None, a temp directory will be created)"
]
},
Expand Down Expand Up @@ -347,11 +349,13 @@
"if port_forwarding:\n",
" hostname = \"localhost\" # use localhost for port forwarding\n",
" print(\n",
" f\"Using port forwarding (ensure it is configured) and localhost. Open your graph here at http://{hostname}:{port}/index.html\"\n",
" f\"Using port forwarding (ensure it is configured) and localhost.\"\n",
" f\" Open your graph here at http://{hostname}:{port}/index.html\"\n",
" )\n",
"else:\n",
" print(\n",
" f\"Not using port forwarding. Use the IFrame below, or open your graph here directly at http://{hostname}:{port}/index.html\"\n",
" f\"Not using port forwarding. Use the IFrame below, or\"\n",
" f\" open your graph here directly at http://{hostname}:{port}/index.html\"\n",
" )\n",
"\n",
"if enable_iframe:\n",
Expand Down
6 changes: 3 additions & 3 deletions src/it_examples/notebooks/publish/.notebook_hashes.json
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
{
"notebooks/dev/attribution_analysis/analysis_injection_config.yaml": "0153616832195ab9adfc9008cdf7f26d73426128f9fe01b9a614929018e1a53a",
"notebooks/dev/attribution_analysis/analysis_points.py": "a74256e26f10d00b4e0abeeb69f3bab7332357dc19ab006fa56afd1d0502db94",
"notebooks/dev/attribution_analysis/attribution_analysis.ipynb": "f70568bf095c184fd50123560e76e1ee4f0cb72cbd1529c841c993680b02fe54",
"notebooks/dev/attribution_analysis/analysis_points.py": "b27a7eb3d6cf02194ed776b4d7e9a885c2ebb563aac73822e3d21b0366c3e8a0",
"notebooks/dev/attribution_analysis/attribution_analysis.ipynb": "89d1509a99ba14a391dbde18d4bda68b276551b52eb4207942b7308f06980a00",
"notebooks/dev/circuit_tracer_examples/circuit_tracer_adapter_example_basic.ipynb": "760c051bfb3932815e400918724661e46ef49fbe10a4e9231442653707dc83c9",
"notebooks/dev/circuit_tracer_examples/circuit_tracer_adapter_example_basic_clt.ipynb": "c74595c292c5964558943f3a2d8786e8e3e75425ab56eb48d112518d9133496e",
"notebooks/dev/circuit_tracer_examples/circuit_tracer_adapter_example_basic_clt.ipynb": "c101e6594dd34cc55c7dd46439e316a8b4a4fff8d7afc1edb9b5bba88156ab0a",
"notebooks/dev/circuit_tracer_examples/gradient_flow_analysis.ipynb": "21818d3fa9765b29b84d38d977a744c185d89fced8913893147d61600c124b45",
"notebooks/dev/example_op_collections/hub_op_collection/hub_op_collection.yaml": "93d81adff9eb7380500df6e598a712d7ac1b73c54a831bdfbfaa05ed701b244b",
"notebooks/dev/example_op_collections/hub_op_collection/hub_op_definitions.py": "ebcc238e3c0cefc94dba7963197f5a6d81a28158f5f13b1bb826859dd5475f60",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,48 +68,49 @@ def ap_build_input_vectors_end(local_vars: Dict[str, Any]) -> None:
v = get_analysis_vars(
context_keys=["target_token_analysis"],
local_keys=[
"logit_idx",
"logit_p",
"targets",
"total_nodes",
"total_active_feats",
"max_feature_nodes",
"edge_matrix",
"row_to_node_index",
"logit_vecs",
"n_layers",
"n_pos",
],
local_vars=local_vars,
)
tta = v["target_token_analysis"]
tta.update_logit_info(v["logit_idx"], v["logit_p"])
max_n_logits = len(v["logit_idx"])
targets = v["targets"]
# Extract token IDs and probabilities from targets object
logit_idx, logit_p = targets.token_ids, targets.logit_probabilities
tta.update_logit_info(logit_idx, logit_p)
max_n_logits = len(targets)

HOOK_REGISTRY.set_context(
max_feature_nodes=v["max_feature_nodes"],
total_nodes=v["total_nodes"],
logit_idx=v["logit_idx"],
logit_p=v["logit_p"],
logit_idx=logit_idx,
logit_p=logit_p,
n_pos=v["n_pos"],
max_n_logits=max_n_logits,
)
data = {
"logit_idx": v["logit_idx"],
"logit_p": VarAnnotate("logit_p", var_value=v["logit_p"], annotation="non-demeaned logits probabilities"),
"logit_idx": logit_idx,
"logit_p": VarAnnotate("logit_p", var_value=logit_p, annotation="logit probabilities"),
"target_tokens": tta.tokens,
"target_logit_indices": tta.logit_indices,
"target_logit_p": tta.logit_probabilities,
"logit_cumulative_prob": float(v["logit_p"].sum().item()),
"logit_cumulative_prob": float(logit_p.sum().item()),
"total_nodes": v["total_nodes"],
"max_feature_nodes": v["max_feature_nodes"],
"total_active_feats": v["total_active_feats"],
"n_logits": len(v["logit_idx"]),
"n_logits": len(targets),
"n_layers": v["n_layers"],
"n_pos": v["n_pos"],
"max_n_logits": max_n_logits,
"edge_matrix.shape": v["edge_matrix"].shape,
"row_to_node_index.shape": v["row_to_node_index"].shape if v["row_to_node_index"] is not None else None,
"logit_vecs.shape": v["logit_vecs"].shape,
"logit_vecs.shape": targets.logit_vectors.shape,
}
analysis_log_point("after building input vectors w/ target logits", data)

Expand Down Expand Up @@ -151,7 +152,7 @@ def ap_compute_feature_attributions_end(local_vars: Dict[str, Any]) -> None:
# Use dict directly for cleaner access
v = get_analysis_vars(
context_keys=["target_token_analysis"],
local_keys=["n_visited", "max_feature_nodes", "logit_p", "edge_matrix", "ctx"],
local_keys=["n_visited", "max_feature_nodes", "targets", "edge_matrix", "ctx"],
local_vars=local_vars,
)
tta = v["target_token_analysis"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
"if should_install_package(\"circuit-tracer\"):\n",
" # Note: Using line continuation for long git URL\n",
" !python -m pip install \\\n",
" 'git+https://github.com/speediedan/circuit-tracer.git@b228bf190fadb3cb30f6a5ba6691dc4c86d76ba3'"
" 'git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87'"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@
"source": [
"# Parameters - These will be injected by papermill during parameterized test runs\n",
"use_baseline_salient_logits = True # logits computation mode: True->salient logits, False->specific logits\n",
"use_baseline_transcoder_arch = True # transcoder architecture: True->SingleLayerTranscoder, False->CrossLayerTranscoder\n",
"use_baseline_transcoder_arch = (\n",
" False # transcoder architecture: True->SingleLayerTranscoder, False->CrossLayerTranscoder\n",
")\n",
"core_log_dir = None # Directory to save analysis logs (if None, a temp directory will be created)"
]
},
Expand Down Expand Up @@ -364,11 +366,13 @@
"if port_forwarding:\n",
" hostname = \"localhost\" # use localhost for port forwarding\n",
" print(\n",
" f\"Using port forwarding (ensure it is configured) and localhost. Open your graph here at http://{hostname}:{port}/index.html\"\n",
" f\"Using port forwarding (ensure it is configured) and localhost.\"\n",
" f\" Open your graph here at http://{hostname}:{port}/index.html\"\n",
" )\n",
"else:\n",
" print(\n",
" f\"Not using port forwarding. Use the IFrame below, or open your graph here directly at http://{hostname}:{port}/index.html\"\n",
" f\"Not using port forwarding. Use the IFrame below, or\"\n",
" f\" open your graph here directly at http://{hostname}:{port}/index.html\"\n",
" )\n",
"\n",
"if enable_iframe:\n",
Expand Down
Loading
Loading