Attribution Targets Encapsulation #52
Squashed changes from arbitrary-logit-attribution branch.
…he functionality proposed in PR decoderesearch#44 and provides a unified foundation for future attribution target enhancements.
Force-pushed 392ab1b to b6bf15f
Pull Request Overview
This PR refactors the attribution system to support flexible logit target specification, including arbitrary string tokens and functions thereof that may not be in the model's vocabulary. The changes introduce a new AttributionTargets container class and LogitTarget data structure while maintaining backward compatibility with the legacy tensor format.
Key Changes
- Introduces `AttributionTargets` class to handle multiple target specification formats (None for auto-selection, `torch.Tensor` for specific token IDs, or mixed lists including arbitrary strings)
- Replaces `Graph.logit_tokens` (`torch.Tensor`) with `Graph.logit_targets` (`list[LogitTarget]`) to support both real vocabulary indices and virtual indices for out-of-vocabulary tokens
- Implements a virtual index mechanism (vocab_size + position) to uniquely identify arbitrary string tokens not in the model's vocabulary
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| circuit_tracer/attribution/targets.py | New module introducing AttributionTargets container and LogitTarget NamedTuple for flexible target specification |
| circuit_tracer/graph.py | Refactored Graph class to use LogitTarget list instead of tensor, added properties for backward compatibility and virtual index handling |
| circuit_tracer/attribution/attribute.py | Updated attribution function to use AttributionTargets, removed deprecated compute_salient_logits function |
| circuit_tracer/utils/create_graph_files.py | Updated node creation to unpack LogitTarget tuples for token strings and vocabulary indices |
| tests/test_attribution_targets.py | Comprehensive test suite for new AttributionTargets class covering all input formats and edge cases |
| tests/test_graph.py | Added tests for Graph serialization/deserialization with both tensor and LogitTarget formats |
| tests/test_attributions_gemma.py | Updated to use new logit_token_ids property instead of deprecated logit_tokens |
| tests/test_attribution_clt.py | Updated to use new logit_token_ids property |
| demos/attribute_demo.ipynb | Updated documentation to reference logit_targets instead of logit_tokens |
Hi, I just wanted to leave a comment and let you know I've seen this PR and the others - thanks for contributing! I have been away for the past week (and will be for the next), but hope to look at this soon.

No worries at all, thanks for letting me know. Looking forward to discussing then! 🚀

@hannamw Any thoughts/feedback on this proposed enhancement? I'm sure you've got an immense amount on your plate; just wanted to ensure it was still in the queue somewhere, or if not, whether there were amendments that could be made to otherwise allow the proposed functionality. The world model analysis framework I'm building will depend on the arbitrary logit attribution target specification you seemed to be exploring in #44, which was closely related to enhancements I'd already been using on my circuit-tracer fork for my pre-MVP package. I think the enhanced attribution target definition/resolution interface could be broadly useful, as I can foresee an increasing volume and variety of use cases on the horizon for your already immensely valuable package! Thanks again for all your work! 🚀 🎉
…ing cases (decoderesearch#51)

- fix: handle single module in offload_modules. Fixes a TypeError when passing a single module (e.g., CrossLayerTranscoder) to offload_modules instead of a list. Now properly handles single modules, lists, and PyTorch container types (ModuleList, ModuleDict, Sequential).
- minor aesthetic change, slight simplification of logic
…e range of gpus the test suite runs with (decoderesearch#50)
…PR for new multi-backend arch
…ch sizes to allow running more tests with minimum CI hardware profile
…entation in AttributionTargets
Hi! Sorry to take so long to get to this - I was working on the nnsight backend, and didn't want to merge anything large until then. I noticed that you updated this with nnsight attribution; is this ready for review? If so, I'll give it a look!
No worries @hannamw, I saw you had a substantial lift with that work and expected that was the case. Nice work on the nnsight addition and multi-backend architectural refactor btw! This PR is ready for re-review, thanks! |
hannamw
left a comment
Thanks a bunch for this big PR! I really appreciate the work that you've put into it. I added comments throughout; let me know what you think. I also think it'd be great to add a demo notebook showcasing this new feature. Thanks again for implementing this!
```python
model: "NNSightReplacementModel | TransformerLensReplacementModel",
*,
attribution_targets: (
    Sequence[tuple[str, float, torch.Tensor] | int | str] | torch.Tensor | None
```
Might be better here to just commit to a format! i.e. input either None, a Sequence of strs, or a sequence of fully specified attribution targets. This cuts down on the number of cases / potential confusion.
Makes sense! Supporting the mixed-mode input is probably not worth the marginal complexity and potential user confusion. I've dropped the mixed-mode list[int | str | tuple] entirely.
The revised interface for `attribution_targets` I'm proposing now is:

`Sequence[str] | Sequence[TargetSpec] | torch.Tensor | None`

- `Sequence[str]`: token strings. `AttributionTargets` handles tokenization; we warn on multi-token strings and take the final token (lmk if you prefer we error instead, or only take the final token if a bool flag is set).
- `Sequence[TargetSpec]`: fully specified custom targets with explicit token_str, probability, and unembed vectors. I thought it was worth making the target specification a named tuple for clarity. We still accept regular tuples and create named tuples for the user if preferred, but keep the signature readable with this alias: `TargetSpec = CustomTarget | tuple[str, float, torch.Tensor]`
- `torch.Tensor`: my reasoning for accepting the `torch.Tensor` input is that it's convenient for downstream consumers that have pre-tokenized IDs they'd like to attribute without roundtripping through strings. The world model analysis framework I'm building does this, invoking circuit-tracer's `attribute` via composable operations with pre-tokenized tensor inputs. If you feel strongly about this we can drop it and require users to convert to `Sequence[str]` or `Sequence[TargetSpec]`, but I think it's a nice-to-have that allows greater breadth of downstream use cases and doesn't add much complexity.
- `None`: auto-select salient logits
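As a rough sketch, the proposed alias and input union could be reconstructed as follows (hypothetical names and field order taken from this thread, not a released API):

```python
from typing import NamedTuple, Sequence, Union

import torch

# Hypothetical reconstruction of the proposed TargetSpec alias; the
# CustomTarget field names follow this discussion, not shipped code.
class CustomTarget(NamedTuple):
    token_str: str                 # label, may be out-of-vocabulary
    probability: float             # probability reported for this target
    unembed_vector: torch.Tensor   # direction to attribute from

TargetSpec = Union[CustomTarget, tuple[str, float, torch.Tensor]]

# Full input union for attribution_targets as proposed above:
AttributionTargetsInput = Union[
    Sequence[str],         # token strings (tokenized internally)
    Sequence[TargetSpec],  # fully specified custom targets
    torch.Tensor,          # pre-tokenized vocab indices
    None,                  # auto-select salient logits
]
```

Since `CustomTarget` is a `NamedTuple`, a plain `("label", 0.5, vec)` tuple remains accepted wherever a `TargetSpec` is expected.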
This sounds okay to me - I recognize that having multiple input formats could be convenient (even though I would like to cut down on the number of formats). Re: warning vs. erroring, I'd prefer to error here. If anything, I think I'd want to take the first token, but erroring still seems best.
Makes sense! Changed our multi-token string handling from warnings.warn() to raise ValueError().
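A minimal sketch of the resulting behavior (toy whitespace tokenizer; not the actual circuit-tracer implementation):

```python
# Sketch of the decision above: multi-token target strings now raise
# ValueError instead of warning and taking the final token.
def resolve_single_token(tokenizer, target: str) -> int:
    ids = tokenizer.encode(target, add_special_tokens=False)
    if len(ids) != 1:
        raise ValueError(
            f"Target {target!r} tokenizes to {len(ids)} tokens; "
            "attribution targets must be single tokens."
        )
    return ids[0]

# Toy tokenizer standing in for a real one.
class ToyTokenizer:
    vocab = {"cat": 0, "dog": 1, "sat": 2}

    def encode(self, text, add_special_tokens=False):
        return [self.vocab[w] for w in text.split()]
```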
circuit_tracer/graph.py (Outdated)

```python
self.logit_targets = attribution_targets.logit_targets
self.logit_probabilities = attribution_targets.logit_probabilities
self.vocab_size = attribution_targets.vocab_size
elif logit_targets is not None and logit_probabilities is not None:
```
Is this for backwards compatibility? I don't mind breaking this, as long as this is clearly indicated in the new release. Or perhaps it would be more elegant to move this into `from_pt`.
Yep, this was for backward compatibility with serialized graphs. Agreed that from_pt is the right place for this — moved the tensor-to-LogitTarget conversion there. Graph.__init__ now only accepts attribution_targets (the container) or logit_targets as list[LogitTarget] directly, making the constructor cleaner.
Thanks! One more request - can we make it so that only one of attribution_targets or (logit_targets and logit_probabilities) is valid? It would be cleaner to have just one option. I guess I would prefer to get rid of the attribution_targets option and just access its attributes when creating the Graph in attribute, seeing as we don't actually use attribution_targets itself.
Makes sense! The AttributionTargets container really is an attribution-phase concern, not something Graph needs to know about. I removed the attribution_targets parameter entirely and made logit_targets and logit_probabilities required positional parameters. attribute_transformerlens.py and attribute_nnsight.py now unpack the relevant attributes from the AttributionTargets instance.
tests/test_attribution_targets.py (Outdated)

```diff
@@ -0,0 +1,384 @@
+"""Unit tests for AttributionTargets class."""
```
Would be nice to have a test that verifies the correctness of the List[tuple] formatted attribution targets. For example, attribute back from logit(x) - logit(y) (i.e. the difference in the corresponding unembed layers), and verify that, upon intervention, the difference in their logits changes as expected
One could also attribute back from None and its equivalent token lists / attribution target tuples and check for consistency
Great suggestions! I thought it'd make sense to add these as integration tests, 2 per backend:

- `test_custom_target_correctness`: constructs a `logit(x) - logit(y)` direction from the unembed matrix, runs attribution with this as a `CustomTarget`, then ablates/amplifies the top features and verifies the logit difference changes in the expected direction. Both ablation and amplification produce pronounced effects in the expected directions.
- `test_attribution_format_consistency`: runs attribution with `None` (auto-select), then with equivalent `Sequence[str]` and `Sequence[CustomTarget]` targets, and verifies all three produce identical results: same targets, probabilities, and adjacency matrices.

All four tests use the real gemma-2-2b model. The TL and NNSight variants share backend-agnostic logic via helper functions.
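The logit-difference direction used by the first test can be illustrated in isolation; here `W_U` is assumed to be a (d_model, vocab) unembedding matrix (a generic sketch, not the test code itself):

```python
import torch

# The direction logit(x) - logit(y) is just the difference of the
# corresponding unembedding columns.
def logit_diff_direction(W_U: torch.Tensor, x: int, y: int) -> torch.Tensor:
    return W_U[:, x] - W_U[:, y]

torch.manual_seed(0)
W_U = torch.randn(8, 16)   # illustrative (d_model=8, vocab=16)
resid = torch.randn(8)     # final residual-stream vector

d = logit_diff_direction(W_U, 3, 7)
logits = resid @ W_U
# Projecting the residual stream onto d recovers the logit difference.
assert torch.allclose(resid @ d, logits[3] - logits[7], atol=1e-4)
```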
Thanks a bunch for adding these! For 1., could we instead do an exact check? If the direction is indeed logit(x) - logit(y), then we should be able to check that the change in logit(x) - logit(y) is precisely what our attribution graph predicts.
Done! I updated test_custom_target_correctness to adopt the same attribution validation approach used elsewhere in the codebase (e.g. `verify_feature_edges` in `test_attributions_gemma.py::test_gemma_2_2b()`).

Note that I needed to set `logit_atol` to be less tight than in `verify_feature_edges` (1e-4 vs 1e-5), but the other 3 tolerance thresholds are identical. I think this is probably defensible given the combined effects of cancellation (increased low-order-bit dependency) and asymmetric accumulation bias (baseline logits of 39 vs 31 create a downward bias for the 39 logit). Empirically, all 50 randomly-sampled features passed at 1e-4, while only ~50% passed at 1e-5. Let me know if you want to try upcasting to float64 or think we need to look at other strategies to get tighter tolerances here.
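The low-order-bit loss mentioned above can be seen with a generic float-precision sketch (illustrative only; unrelated to the actual test values beyond the ~39 logit magnitude):

```python
import torch

# Float32 spacing (ulp) near 39 is 2**-18 ≈ 3.8e-6, so perturbations
# smaller than half an ulp vanish entirely in float32 arithmetic.
x32 = torch.tensor(39.0, dtype=torch.float32)
assert bool((x32 + 1e-6) - x32 == 0.0)   # lost below the float32 ulp

# The same perturbation survives in float64.
x64 = torch.tensor(39.0, dtype=torch.float64)
assert ((x64 + 1e-6) - x64).item() != 0.0
```

This is why differences between large, nearly equal logits inherit relatively large absolute error even when each logit is individually accurate.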
… review feedback, includes comment and mark changes that will be separated into a separate PR
Thanks again for the thorough review @hannamw! I've addressed all of your feedback in the latest commits and extracted the comment-change hunks into a separate PR that I'll open, in order to keep the focus of this PR on the attribution targets encapsulation. Great suggestion on the demo notebook. Is it okay if I add that in a separate PR as soon as I get a chance in the next few days?
hannamw
left a comment
Thanks again for these changes! I've left just a few responses in some of the threads, but things are otherwise looking good. I'm looking forward to seeing the demo notebook and trying it out for myself!
…tion validation approach, remove attribution_targets ref from `Graph`, remove unnecessary vocab_indices properties, convert multi-token string detection handling from warning to error
…icit maintainer/community feedback
- refined initial description and overview of the attribution targets API - updated colab notebook setup to allow testing the attribution-targets branch before merging - improved formatting of the attribution target, token mapping outputs and top feature comparison tables - verified the demo notebook is working as expected in colab
…an_cuda` context manager to allow simple orchestration of sequential backend parity tests. Useful for VRAM constrained environments or to enable parity testing of larger models that may not fit on GPU simultaneously.
Just had a chance to look at this - the notebook looks good! The only change that I would suggest is to reorder it a bit, starting with the simpler types of attribution targets (None, then List[str] and torch.Tensor). I think this will be less confusing to readers, who might not want to immediately jump into diff targets or semantic concept targets.
…cted the `CustomTarget` examples and helper functions discussion to a distinct section. added a torch.Tensor version of the `Sequence[str]` example for completeness.
Good suggestion! I've restructured the demo to lead with the simpler target modes and extracted the `CustomTarget` examples into a distinct section. I've also added a `torch.Tensor` version of the `Sequence[str]` example for completeness.
hannamw
left a comment
Thanks for all of these changes! I think this is ready to be merged.
Overview
This PR proposes an enhancement (`AttributionTargets`) that encapsulates the functionality proposed in PR #44 and provides a unified foundation for future attribution target extensions. It encapsulates attribution target definition/resolution while extending configuration flexibility; the defined interface should robustly accommodate future attribution target features and shield users from internal refactors.
Key constructs:

- `LogitTarget` (NamedTuple): a lightweight record pairing `token_str` with `vocab_idx`
- `AttributionTargets` (container): supports `token_id`-tensor-driven filtering and other access patterns using views of the same underlying data structure

Key Changes
1. `AttributionTargets` (targets.py)

New module introducing two key constructs:
LogitTarget (NamedTuple)

A lightweight data transfer object for storing token metadata:

- `token_str`: human-readable token representation (decoded from vocabulary, or an arbitrary string)
- `vocab_idx`: either a real token ID (< vocab_size) or a virtual index (>= vocab_size) for OOV tokens

AttributionTargets (Container Class)
High-level encapsulation of attribution target specifications with four construction modes:
Mode 1: Automatic (Salient Logits) - the current default option
Automatically selects the minimal set of top logits whose cumulative probability exceeds the threshold.
Mode 2: Tensor-based (Explicit Token IDs) 1
Directly specify vocabulary indices; probabilities and vectors computed automatically.
Mode 3: Sequence of Strings - supports any sequence type (list, tuple, etc.)
Token strings are tokenized and probabilities/vectors computed automatically. Sequences must contain only strings.
Mode 4: Sequence of Custom Targets - this approach supports the arbitrary logit attribution proposed in PR #44
Fully-specified custom targets with arbitrary strings/functions/transformations. Uses virtual indices for OOV tokens. Sequences must contain only TargetSpec elements (CustomTarget or tuple[str, float, Tensor]).
Important: Sequences must be homogeneous - all strings OR all TargetSpec, not mixed. The type is determined by inspecting the first element.
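The homogeneity rule implies a simple dispatch on the input; a sketch of how the mode might be classified (assumed semantics, not the real validation code, which is more thorough):

```python
from collections.abc import Sequence

import torch

# Illustrative dispatch over the four homogeneous input modes described
# above; the type of a sequence is decided by its first element.
def classify_targets(spec) -> str:
    if spec is None:
        return "auto"       # Mode 1: salient logits
    if isinstance(spec, torch.Tensor):
        return "tensor"     # Mode 2: explicit token IDs
    if isinstance(spec, Sequence) and len(spec) > 0:
        return "strings" if isinstance(spec[0], str) else "custom"
    raise ValueError(f"Unsupported attribution_targets: {spec!r}")
```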
Key Properties:

Virtual Indices:

The approach used for representing arbitrary strings/functions not in the tokenizer's vocabulary using synthetic indices >= vocab_size. This enables attribution for arbitrary string targets (e.g. "func(token)").

Encoding: `virtual_idx = vocab_size + position_in_list`

Note we don't actually append these tokens to the tokenizer in this initial implementation.
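A minimal sketch of the `LogitTarget` record and the virtual-index encoding described above (field names per this PR; `VOCAB_SIZE` is illustrative, not gemma-2-2b's actual vocab size):

```python
from typing import NamedTuple

VOCAB_SIZE = 50_000  # illustrative vocab size

class LogitTarget(NamedTuple):
    token_str: str   # decoded token, or arbitrary label for OOV targets
    vocab_idx: int   # real token ID (< vocab_size) or virtual index

def virtual_target(label: str, position_in_list: int,
                   vocab_size: int = VOCAB_SIZE) -> LogitTarget:
    # Encoding from the PR: virtual_idx = vocab_size + position_in_list
    return LogitTarget(token_str=label,
                       vocab_idx=vocab_size + position_in_list)

def is_virtual(t: LogitTarget, vocab_size: int = VOCAB_SIZE) -> bool:
    return t.vocab_idx >= vocab_size
```

Because real token IDs are always below `vocab_size`, the two index ranges never collide, so a single integer field disambiguates real and virtual targets.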
2. Graph Class Integration (`circuit_tracer/graph.py`)

Multiple construction patterns supported:

- The tokenizer is not stored on `Graph` (which would enable reconstruction of `token_str` when only the original tensor of token_ids/vocab_indices is provided); the complexity and additional object weight are not justified by the marginal utility.

New helper properties:

- `vocab_indices`: list of all vocabulary indices (including virtual)
- `has_virtual_indices`: boolean check using stored vocab_size (deterministic)
- `logit_token_ids`: tensor of real vocab indices only (raises if virtual indices are present); `AttributionTargets` uses token_ids since disambiguation shouldn't be necessary in that context 2
- `vocab_size`: stored vocab size for reliable virtual index detection
- `logit_tokens` (deprecated): alias for token_ids for backward compatibility 3

Storage model:

Graph stores only primitive objects: `logit_targets` (list[NamedTuple]), `logit_probabilities` (Tensor), `vocab_size` (int). The `AttributionTargets` container is not persisted (avoids serializing the tokenizer and other no-longer-required components).

The Graph integration uses structured LogitTarget records instead of raw tensors:
Before:

- `logit_targets`: torch.Tensor of shape (k,) containing vocab indices
- `logit_probabilities`: torch.Tensor of shape (k,)

After:

- `logit_targets`: list[LogitTarget], data transfer objects with `token_str` and `vocab_idx`
- `logit_probabilities`: torch.Tensor of shape (k,) (unchanged)
- `vocab_size`: int (new; enables deterministic virtual index detection)

The performance impact of the above change should be negligible for typical use cases:
The overhead might be noticeable in edge cases like k > 1000 logits or loading thousands of graphs, but I don't foresee those usage patterns.
3. Attribute Function Integration (attribute.py)

API Signature Change:

Key improvements:

- Moved salient logit computation from the standalone `compute_salient_logits()` function into `AttributionTargets._from_salient()` as one supported construction path, for better encapsulation.

Unified target processing:
4. Benefits of the API Enhancement
Encapsulation & Single Responsibility

`AttributionTargets` handles target definition; `attribute()` handles graph computation.

Type Safety & Validation

All four construction modes include comprehensive validation with clear error messages.

Future-Proof Interface

The container pattern accommodates future enhancements to `AttributionTargets` without breaking changes.

Backward Compatibility Strategy

The `Graph` class maintains compatibility while supporting new patterns (though we have more breaking-change latitude at the moment, this is still nice). The `logit_tokens` property is deprecated (not removed) with a clear migration path 3.
- Discoverability: properties self-document available access patterns
- Flexibility: the same workflow supports multiple input styles
- Error prevention: virtual index checks prevent runtime errors
5. Testing Coverage (test_attribution_targets.py)
Tried to keep testing parameterized and limited/lightweight. Claude kept wanting to add more tests, so I've pared it back substantially, but I can prune/refactor more if desired:
6. Migration Path
For existing code (minimal changes required) 3:
For new code (use new patterns):
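A toy sketch of the migration, using a stand-in class (the real `Graph.logit_token_ids` returns a tensor and raises on virtual indices; this sketch only illustrates the deprecated-alias pattern):

```python
import warnings
from typing import NamedTuple

class LogitTarget(NamedTuple):  # field names per the PR description
    token_str: str
    vocab_idx: int

class GraphSketch:
    """Toy stand-in for Graph, to show old vs. new access paths."""

    def __init__(self, logit_targets: list[LogitTarget]):
        self.logit_targets = logit_targets

    @property
    def logit_token_ids(self) -> list[int]:
        return [t.vocab_idx for t in self.logit_targets]

    @property
    def logit_tokens(self) -> list[int]:
        # Deprecated alias kept for backward compatibility.
        warnings.warn("logit_tokens is deprecated; use logit_token_ids",
                      DeprecationWarning, stacklevel=2)
        return self.logit_token_ids
```

Existing callers of `graph.logit_tokens` keep working with a deprecation warning, while new code uses `graph.logit_token_ids` or iterates `graph.logit_targets` directly for the `token_str` labels.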
7. Multi-backend Compatibility 🔀
This PR has been adjusted to accommodate the multi-backend architecture changes introduced in commit 9317b2a. The adjustments ensure that `AttributionTargets`, the `Graph` integration, and the downstream attribute logic behave correctly across the different model backends supported by circuit-tracer (TransformerLens, NNsight). Tests and CI selections were validated against the updated multi-backend behavior where feasible.

Current test status when running on a 24GB VRAM GPU used for local validation:
Summary
This PR provides a robust, extensible foundation for attribution target specification while maintaining backward compatibility. The `AttributionTargets` encapsulation enables future enhancements (custom probability methods, alternative target types, etc.) without breaking existing workflows. The four construction modes balance convenience (auto-select) with control (explicit specification via tensors or sequences), while the use of virtual indices supports attribution for arbitrary string representations beyond the tokenizer vocabulary.

Thank you so much to the circuit-tracer authors for making this immensely valuable contribution to the open-source interpretability ecosystem!
While experimental, I see circuit-tracer as a foundationally important component enabling future open-source ML interpretability work. I'm using it extensively in a downstream analysis framework I'm building and couldn't appreciate your work more! I should also thank Claude 4 for its help (somewhat obviously) creating the scaffolding of this PR and for generating/refining the unit tests.
Footnotes

1. It's worth noting I've used this approach extensively in a fork of circuit-tracer I maintain for an AI world model analysis framework I'm building (fully open-source, but pre-MVP so no official links to share yet).
2. Let me know if you think we should use the same `logit_token_ids` property for both (clarity >> brevity principle).
3. Let me know if you think we should retain this and mark it deprecated, keep it without deprecating, or remove it outright given the breaking-changes flexibility we retain. I thought deprecating was the safest approach to present in the PR initially, at least.
4. Transitively thanking the authors of circuit-tracer again for my ability to thank Claude; they really deserve the attribution.