
Attribution Targets Encapsulation#52

Merged
hannamw merged 20 commits into decoderesearch:main from speediedan:attribution-targets
Feb 23, 2026

Conversation

@speediedan
Contributor

@speediedan speediedan commented Nov 7, 2025

Overview

This PR proposes an enhancement (AttributionTargets) that encapsulates the functionality proposed in PR #44 and provides a unified foundation for future attribution target extensions.

Building on the functionality proposed in PR #44, it encapsulates attribution target definition/resolution while extending configuration flexibility. The defined interface should robustly accommodate future attribution target features and shield users from the impact of internal refactors.

Key constructs:

  • AttributionTargets: High-level container that encapsulates attribution target specifications, probabilities, and unembedding vectors, offering several different target resolution and construction methods.
  • LogitTarget: Low-level data transfer object; a lightweight record storing token metadata (token_str, vocab_idx)
  • vocab_idx/vocab_indices: A generalization of token_ids that can also include virtual indices (synthetic indices representing OOV strings using indices >= vocab_size).
    • This supports logit attribution for arbitrary strings (or functions/transformations thereof) while continuing to allow existing token_id tensor-driven filtering and other access patterns via views of the same underlying data structure.
    • Users can read the token_ids property to get a tensor of token IDs in the tokenizer's vocab space, vocab_indices to get both the virtual indices for arbitrary strings and the resolved token IDs, tokens to get the string form of all attribution targets, etc.

Key Changes

1. AttributionTargets (targets.py)

New module introducing two key constructs:

LogitTarget (NamedTuple)

A lightweight data transfer object for storing token metadata:

LogitTarget(token_str: str, vocab_idx: int)
  • token_str: Human-readable token representation (decoded from vocabulary or arbitrary string)
  • vocab_idx: Either a real token ID (< vocab_size) or virtual index (>= vocab_size) for OOV tokens
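A NamedTuple of this shape is only a few lines of standard Python; the sketch below mirrors the two fields named in the PR (everything else, including the docstring and example values, is illustrative):

```python
from typing import NamedTuple

class LogitTarget(NamedTuple):
    """One attribution target: a readable string plus its vocab index."""
    token_str: str  # decoded token, or an arbitrary label for OOV targets
    vocab_idx: int  # real token ID (< vocab_size) or virtual index (>= vocab_size)

# Behaves like a plain tuple, so tuple-unpacking call sites keep working
t = LogitTarget("the", 262)
token_str, vocab_idx = t
```

Because NamedTuples are tuples, existing code that unpacks `(token_str, vocab_idx)` pairs continues to work unchanged.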

AttributionTargets (Container Class)

High-level encapsulation of attribution target specifications with four construction modes:

Mode 1: Automatic (Salient Logits) - the current default option

targets = AttributionTargets(
    attribution_targets=None,  # Auto-select mode
    logits=logits,
    unembed_proj=unembed,
    tokenizer=tokenizer,
    max_n_logits=10,
    desired_logit_prob=0.95
)

Automatically selects the minimal set of top logits whose cumulative probability exceeds the threshold.
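The selection rule ("minimal set of top logits whose cumulative probability exceeds the threshold") can be sketched in pure Python; this is illustrative only, as the PR's actual construction path presumably operates on torch tensors:

```python
import math

def select_salient_logits(logits, max_n_logits=10, desired_logit_prob=0.95):
    """Return the smallest top-k index set whose cumulative softmax
    probability reaches the threshold, capped at max_n_logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # shift by max for stability
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    chosen, cum = [], 0.0
    for i in ranked[:max_n_logits]:
        chosen.append(i)
        cum += probs[i]
        if cum >= desired_logit_prob:
            break
    return chosen, [probs[i] for i in chosen]
```

With a distribution of (0.6, 0.3, 0.05, 0.05) and a 0.85 threshold, the first two indices are selected, since 0.6 alone falls short but 0.6 + 0.3 crosses it.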

Mode 2: Tensor-based (Explicit Token IDs) 1

targets = AttributionTargets(
    attribution_targets=torch.tensor([262, 290, 345]),
    logits=logits,
    unembed_proj=unembed,
    tokenizer=tokenizer
)

Directly specify vocabulary indices; probabilities and vectors computed automatically.

Mode 3: Sequence of Strings - supports any sequence type (list, tuple, etc.)

targets = AttributionTargets(
    attribution_targets=["hello", "world", "test"],  # or a tuple: ("hello", "world", "test")
    logits=logits,
    unembed_proj=unembed,
    tokenizer=tokenizer
)

Token strings are tokenized and probabilities/vectors computed automatically. Sequences must contain only strings.

Mode 4: Sequence of Custom Targets - this approach supports the arbitrary logit attribution proposed in PR #44

targets = AttributionTargets(
    attribution_targets=[
        ("func(x)", 0.1, custom_vec),           # Raw tuple
        CustomTarget("g(y)", 0.2, another_vec), # Named tuple
    ],
    logits=logits,
    unembed_proj=unembed,
    tokenizer=tokenizer
)

Fully-specified custom targets with arbitrary strings/functions/transformations. Uses virtual indices for OOV tokens. Sequences must contain only TargetSpec elements (CustomTarget or tuple[str, float, Tensor]).

Important: Sequences must be homogeneous - all strings OR all TargetSpec, not mixed. The type is determined by inspecting the first element.
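The homogeneity rule can be sketched as a small dispatch helper (hypothetical; `classify_targets` is not the PR's actual function name). Note that a CustomTarget NamedTuple is itself a tuple instance, so a single isinstance check covers both TargetSpec forms:

```python
def classify_targets(attribution_targets):
    """Determine the sequence type from its first element and reject
    mixed sequences (sketch of the homogeneity rule described above)."""
    if len(attribution_targets) == 0:
        raise ValueError("attribution_targets must be non-empty")
    if isinstance(attribution_targets[0], str):
        if not all(isinstance(t, str) for t in attribution_targets):
            raise ValueError("Mixed sequence: expected all strings")
        return "strings"
    # TargetSpec elements are CustomTarget (a NamedTuple) or
    # tuple[str, float, Tensor]; both are tuple instances
    if not all(isinstance(t, tuple) for t in attribution_targets):
        raise ValueError("Mixed sequence: expected all TargetSpec elements")
    return "target_specs"
```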

Key Properties:

# Access patterns
targets.tokens                    # List[str] - token strings
targets.vocab_indices             # List[int] - all indices (including virtual)
targets.token_ids                 # Tensor - only real vocab indices (raises if virtual present)
targets.has_virtual_indices       # bool - check for OOV tokens
targets.logit_probabilities       # Tensor - softmax probabilities
targets.logit_vectors             # Tensor - demeaned unembedding vectors

Virtual Indices:
Arbitrary strings/functions not in the tokenizer's vocabulary are represented using synthetic indices >= vocab_size. This enables attribution for:

  • Composed/transformed tokens (e.g., "func(token)")
  • Multi-token aggregations
  • Any string representation useful for analysis

Encoding: virtual_idx = vocab_size + position_in_list
Note that we don't actually append these tokens to the tokenizer in this initial implementation.
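A minimal sketch of the encoding (the VOCAB_SIZE constant and helper names are illustrative; the real vocab size comes from the tokenizer):

```python
VOCAB_SIZE = 50257  # illustrative value; in practice taken from the tokenizer

def resolve_vocab_idx(target_str, vocab, position_in_list):
    """Map a target string to its real token ID when in-vocab,
    otherwise to a virtual index: vocab_size + position_in_list."""
    if target_str in vocab:
        return vocab[target_str]
    return VOCAB_SIZE + position_in_list

def is_virtual(idx, vocab_size=VOCAB_SIZE):
    """Deterministic virtual-index check, in the spirit of has_virtual_indices."""
    return idx >= vocab_size
```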

2. Graph Class Integration (circuit_tracer/graph.py)

Multiple construction patterns supported:

# Option 1: AttributionTargets container
Graph(..., attribution_targets=targets)

# Option 2: LogitTarget list (structured - for creating with known token strings)
Graph(...,
      logit_targets=[LogitTarget("the", 262), LogitTarget("a", 290)],
      logit_probabilities=probs,
      vocab_size=50257)

# Option 3: Tensor
Graph(...,
      logit_targets=torch.tensor([262, 290]),  # token_str will be empty
      logit_probabilities=probs,
      vocab_size=50257)
  • Claude really accelerated my traversal of the design space here. I decided against keeping a tokenizer ref in Graph to enable reconstruction of token_str when only the original tensor of token_ids/vocab_indices is provided (the added complexity and object weight aren't justified by the marginal utility)

New helper properties:

  • vocab_indices - List of all vocabulary indices (including virtual)
  • has_virtual_indices - Boolean check using stored vocab_size (deterministic)
  • logit_token_ids - Tensor of real vocab indices only (raises if virtual indices present)
    • Note: I chose the logit_token_ids prefix over plain token_ids to disambiguate it from the prompt tokens referenced in the Graph context. AttributionTargets uses token_ids since disambiguation shouldn't be necessary in that context 2.
  • vocab_size - Stored vocab size for reliable virtual index detection
  • logit_tokens (deprecated) - Alias for logit_token_ids, kept for backward compatibility 3

Storage model:

  • Graph stores only primitive objects: logit_targets (list[NamedTuple]), logit_probabilities (Tensor), vocab_size (int)

  • The AttributionTargets container is not persisted (this avoids serializing the tokenizer and other components that are no longer required)

  • The Graph integration uses structured LogitTarget records instead of raw tensors:

    Before:

    • logit_targets: torch.Tensor of shape (k,) containing vocab indices
    • logit_probabilities: torch.Tensor of shape (k,)

    After:

    • logit_targets: list[LogitTarget] - data transfer objects with token_str and vocab_idx
    • logit_probabilities: torch.Tensor of shape (k,) (unchanged)
    • vocab_size: int (new - enables deterministic virtual index detection)

The performance impact of the above change should be negligible for typical use cases:

  • Storage increase: <0.05% for example graphs
  • Load time increase: <1% for example graphs
  • Memory overhead: <1KB per graph
  • Benefits outweigh the costs IMHO:
    • Deterministic virtual index detection
    • Better debugging (human-readable token strings)

The overhead might be noticeable in edge cases, such as k > 1000 logits or loading thousands of graphs, but I don't foresee those usage patterns.

3. Attribute Function Integration (attribute.py)

API Signature Change:

# Before
attribute(prompt, model, quantity_to_attribute=None, ...)

# After  
attribute(prompt, model, attribution_targets=None, ...)

Key improvements:

  1. Moved salient logit computation from standalone compute_salient_logits() function into AttributionTargets._from_salient() as one supported construction path - better encapsulation

  2. Unified target processing:

# Before: Manual tuple unpacking
if quantity_to_attribute is not None:
    logit_idx, logit_p, logit_vecs = zip(*quantity_to_attribute)
    logit_p = torch.tensor(logit_p)
    logit_vecs = torch.stack(logit_vecs)
else:
    logit_idx, logit_p, logit_vecs = compute_salient_logits(...)

# After: Single construction point
targets = AttributionTargets(
    attribution_targets=attribution_targets,
    logits=ctx.logits[0, -1],
    unembed_proj=model.unembed.W_U,
    tokenizer=model.tokenizer,
    max_n_logits=max_n_logits,
    desired_logit_prob=desired_logit_prob,
)
  3. Cleaner downstream usage:
# Access via properties instead of separate variables
for i in range(0, len(targets), batch_size):
    batch = targets.logit_vectors[i : i + batch_size]
    # ...

influences = compute_partial_influences(
    edge_matrix[:st], 
    targets.logit_probabilities,  # Instead of logit_p
    row_to_node_index[:st]
)

4. Benefits of the API Enhancement

Encapsulation & Single Responsibility

  • Attribution target specification, validation, and computation now unified in one class
  • Clear separation: AttributionTargets handles target definition; attribute() handles graph computation
  • Internal refactors don't affect user code

Type Safety & Validation

# Automatic validation of token indices
targets = AttributionTargets(
    attribution_targets=torch.tensor([999999]),  # Out of vocab
    logits=logits, ...
)
# ValueError: Token indices must be in range [0, vocab_size)

All four construction modes include comprehensive validation with clear error messages.

Future-Proof Interface

The container pattern accommodates future enhancements without breaking changes:

  • New target types (e.g., token ranges, regex patterns)
  • Additional metadata (e.g., custom weights, constraints)
  • Alternative probability computation methods
  • All added as new properties/methods on AttributionTargets

Backward Compatibility Strategy

Graph class maintains compatibility while supporting new patterns (though we currently have more latitude for breaking changes, this is still nice):

# Old serialized graphs can still load
graph = Graph(..., 
    logit_targets=torch.tensor([...]),  # Legacy format
    logit_probabilities=probs,
    vocab_size=50257
)

# New code uses AttributionTargets container
graph = Graph(..., attribution_targets=targets)

The logit_tokens property is deprecated (not removed), with a clear migration path 3:

@property
def logit_tokens(self) -> torch.Tensor:
    warnings.warn(
        "logit_tokens deprecated. Use logit_token_ids instead.",
        DeprecationWarning, stacklevel=2
    )
    return self.logit_token_ids

Developer Experience Improvements

Discoverability: Properties self-document available access patterns

targets.  # IDE autocomplete shows: tokens, vocab_indices, token_ids, 
          # has_virtual_indices, logit_probabilities, logit_vectors

Flexibility: Same workflow supports multiple input styles

# Quick exploration - let system choose
graph = attribute(prompt, model)

# Specific analysis - provide exact tokens  
graph = attribute(prompt, model, attribution_targets=torch.tensor([262, 290]))

# String-based attribution
graph = attribute(prompt, model, attribution_targets=["the", "a", "an"])

# Advanced research - custom targets with arbitrary directions
graph = attribute(prompt, model, attribution_targets=[
    ("custom_direction", 0.1, custom_vec),
    CustomTarget("func(x)", 0.2, another_vec)
])

Error Prevention: Virtual index checks prevent runtime errors

if not targets.has_virtual_indices:
    # Safe to use tensor operations
    selected_logits = all_logits[targets.token_ids]
else:
    # Must handle virtual indices separately
    for idx, target in zip(targets.vocab_indices, targets.logit_targets):
        # ... custom handling

5. Testing Coverage (test_attribution_targets.py)

I tried to keep the tests parameterized and lightweight. Claude kept wanting to add more tests, so I've pared the suite back substantially, but I can prune/refactor more if desired:

  • All four construction modes (salient, tensor, Sequence[str], Sequence[TargetSpec])
  • Tuple input validation (ensuring non-list sequences work)
  • Virtual index handling
  • Validation and error cases (including rejection of heterogeneous sequences)
  • Property access patterns
  • Edge cases (empty inputs, out-of-range indices, etc.)

6. Migration Path

For existing code (minimal changes required) 3:

# Old API still works with deprecation warning
graph.logit_tokens  # Works, shows warning

# Recommended update
graph.logit_token_ids  # New property name

For new code (use new patterns):

# Use AttributionTargets container for flexibility
targets = AttributionTargets(...)
graph = attribute(prompt, model, attribution_targets=targets)

# Or let attribute() construct it from strings
graph = attribute(prompt, model, attribution_targets=["hello", "world"])

# Or from custom targets
graph = attribute(prompt, model, attribution_targets=[
    CustomTarget("token", 0.5, vec1),
    ("func(x)", 0.3, vec2)
])

7. Multi-backend Compatibility 🔀

This PR has been adjusted to accommodate the multi-backend architecture changes introduced in commit 9317b2a. The adjustments ensure that AttributionTargets, Graph integration, and the downstream attribute logic behave correctly across the different model backends supported by circuit-tracer (TransformerLens, NNsight). Tests and CI selections were validated against the updated multi-backend behavior where feasible.

Current test status when running on a 24GB VRAM GPU used for local validation:

  • 130 tests passing (skipping tests that triggered OOM with 24 GB VRAM)
  • 43 tests deselected during the run, gated on VRAM requirements

Summary

This PR provides a robust, extensible foundation for attribution target specification while maintaining backward compatibility. The AttributionTargets encapsulation enables future enhancements (custom probability methods, alternative target types, etc.) without breaking existing workflows. The four construction modes balance convenience (auto-select) with control (explicit specification via tensors or sequences) while the use of virtual indices supports attribution for arbitrary string representations beyond the tokenizer vocabulary.

Thank you so much to the circuit-tracer authors for making this immensely valuable contribution to the open-source interpretability ecosystem!

While experimental, I see circuit-tracer as a foundationally important component enabling future open-source ML interpretability work. I'm using it extensively in a downstream analysis framework I'm building and couldn't appreciate your work more! I should also thank Claude 4 for its help (somewhat obviously) creating the scaffolding of this PR and for generating/refining the unit tests.

Footnotes

  1. It's worth noting I've used this approach extensively in a fork of circuit-tracer I maintain for an AI world model analysis framework I'm building (fully open-source but pre-MVP so no official links to share yet)

  2. Let me know if you think we should use the same logit_token_ids property for both (clarity >> brevity principle)

  3. Let me know if you think we should retain this but mark it deprecated, keep it without deprecating it, or remove it given the breaking-changes flexibility we retain. I thought deprecating was the safest approach to present in the PR initially, at least. 2 3

  4. Transitively thanking the authors of circuit-tracer again for my ability to thank Claude, they really deserve the attribution

hannamw and others added 2 commits November 5, 2025 16:02
Squashed changes from arbitrary-logit-attribution branch.
…he functionality proposed in PR decoderesearch#44 and provides a unified foundation for future attribution target enhancements.
@speediedan speediedan marked this pull request as ready for review November 7, 2025 00:28
Copilot AI review requested due to automatic review settings November 7, 2025 00:28

Copilot AI left a comment


Pull Request Overview

This PR refactors the attribution system to support flexible logit target specification, including arbitrary string tokens and functions thereof that may not be in the model's vocabulary. The changes introduce a new AttributionTargets container class and LogitTarget data structure while maintaining backward compatibility with the legacy tensor format.

Key Changes

  • Introduces AttributionTargets class to handle multiple target specification formats (None for auto-selection, torch.Tensor for specific token IDs, or mixed lists including arbitrary strings)
  • Replaces Graph.logit_tokens (torch.Tensor) with Graph.logit_targets (list[LogitTarget]) to support both real vocabulary indices and virtual indices for out-of-vocabulary tokens
  • Implements virtual index mechanism (vocab_size + position) to uniquely identify arbitrary string tokens not in the model's vocabulary

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
circuit_tracer/attribution/targets.py New module introducing AttributionTargets container and LogitTarget NamedTuple for flexible target specification
circuit_tracer/graph.py Refactored Graph class to use LogitTarget list instead of tensor, added properties for backward compatibility and virtual index handling
circuit_tracer/attribution/attribute.py Updated attribution function to use AttributionTargets, removed deprecated compute_salient_logits function
circuit_tracer/utils/create_graph_files.py Updated node creation to unpack LogitTarget tuples for token strings and vocabulary indices
tests/test_attribution_targets.py Comprehensive test suite for new AttributionTargets class covering all input formats and edge cases
tests/test_graph.py Added tests for Graph serialization/deserialization with both tensor and LogitTarget formats
tests/test_attributions_gemma.py Updated to use new logit_token_ids property instead of deprecated logit_tokens
tests/test_attribution_clt.py Updated to use new logit_token_ids property
demos/attribute_demo.ipynb Updated documentation to reference logit_targets instead of logit_tokens


@hannamw
Collaborator

hannamw commented Nov 13, 2025

Hi, I just wanted to leave a comment and let you know I've seen this PR and the others - thanks for contributing! I have been away for the past week (and will be for the next), but hope to look at this soon.

@speediedan
Contributor Author

Hi, I just wanted to leave a comment and let you know I've seen this PR and the others - thanks for contributing! I have been away for the past week (and will be for the next), but hope to look at this soon.

No worries at all, thanks for letting me know. Looking forward to discussing then! 🚀

@speediedan
Contributor Author

@hannamw Any thoughts/feedback on this proposed enhancement? I'm sure you've got an immense amount on your plate, just wanted to ensure it was still in the queue somewhere or if not, if there were amendments that could be made to otherwise allow the proposed functionality.

The world model analysis framework I'm building will depend on the arbitrary logit attribution target specification you seemed to be exploring in #44, which was closely related to enhancements I'd already been using on the circuit-tracer fork for my pre-MVP package. I think the enhanced attribution target definition/resolution interface could be broadly useful as I can foresee an increasing volume and variety of use cases on the horizon for your already immensely valuable package! Thanks again for all your work! 🚀 🎉

…ing cases (decoderesearch#51)

* fix: handle single module in offload_modules

Fix TypeError when passing a single module (e.g., CrossLayerTranscoder)
to offload_modules instead of a list. Now properly handles single
modules, lists, and PyTorch container types (ModuleList, ModuleDict,
Sequential).

* minor aesthetic change, slight simplification of logic
…ch sizes to allow running more tests with minimum CI hardware profile
@hannamw
Collaborator

hannamw commented Jan 28, 2026

Hi! Sorry to take so long to get to this - I was working on the nnsight backend, and didn't want to merge anything large until then. I noticed that you updated this with nnsight attribution; is this ready for review? If so, I'll give it a look!

@speediedan
Contributor Author

speediedan commented Jan 28, 2026

Hi! Sorry to take so long to get to this - I was working on the nnsight backend, and didn't want to merge anything large until then. I noticed that you updated this with nnsight attribution; is this ready for review? If so, I'll give it a look!

No worries @hannamw, I saw you had a substantial lift with that work and expected that was the case. Nice work on the nnsight addition and multi-backend architectural refactor btw!

This PR is ready for re-review, thanks!

Collaborator

@hannamw hannamw left a comment


Thanks a bunch for this big PR! I really appreciate the work that you've put into it. I added comments throughout; let me know what you think. I also think it'd be great to add a demo notebook showcasing this new feature. Thanks again for implementing this!

model: "NNSightReplacementModel | TransformerLensReplacementModel",
*,
attribution_targets: (
Sequence[tuple[str, float, torch.Tensor] | int | str] | torch.Tensor | None
Collaborator


Might be better here to just commit to a format! i.e. input either None, a Sequence of strs, or a sequence of fully specified attribution targets. This cuts down on the number of cases / potential confusion.

Contributor Author


Makes sense! Supporting the mixed-mode input is probably not worth the marginal complexity and potential user confusion. I've dropped the mixed-mode list[int | str | tuple] entirely.

The revised interface for attribution_targets I'm proposing now is:
Sequence[str] | Sequence[TargetSpec] | torch.Tensor | None

  • Sequence[str] — token strings (AttributionTargets handles tokenization, we warn on multi-token strings and take the final token, lmk if you prefer we error instead or only take the final token if a bool flag is set)
  • Sequence[TargetSpec] — fully specified custom targets with explicit token_str, probability, and unembed vectors. I thought it was worth making the target specification a named tuple for clarity. We still accept regular tuples and create named tuples for the user if preferred but keep the signature readable with this alias:
    TargetSpec = CustomTarget | tuple[str, float, torch.Tensor]
  • torch.Tensor — My reasoning for accepting the torch.Tensor input is that it's convenient for downstream consumers that have pre-tokenized IDs they'd like to attribute without round-tripping through strings. The world model analysis framework I'm building does this, invoking circuit-tracer's attribute via composable operations with pre-tokenized tensor inputs. If you feel strongly about this we can drop it and require users to convert to Sequence[str] or Sequence[TargetSpec], but I think it's a nice-to-have that allows a greater breadth of downstream use cases and doesn't add much complexity.
  • None — auto-select salient logits

Collaborator


This sounds okay to me - I recognize that having multiple input formats could be convenient (even though I would like to cut down on the number of formats). Re: warning vs. erroring, I'd prefer to error here. If anything, I think I'd want to take the first token, but erroring still seems best.

Contributor Author


Makes sense! Changed our multi-token string handling from warnings.warn() to raise ValueError().

self.logit_targets = attribution_targets.logit_targets
self.logit_probabilities = attribution_targets.logit_probabilities
self.vocab_size = attribution_targets.vocab_size
elif logit_targets is not None and logit_probabilities is not None:
Collaborator


Is this for backwards compatibility? I don't mind breaking this, as long as this is clearly indicated in the new release. Or perhaps more elegant would be to move this into from_pt

Contributor Author


Yep, this was for backward compatibility with serialized graphs. Agreed that from_pt is the right place for this — moved the tensor-to-LogitTarget conversion there. Graph.__init__ now only accepts attribution_targets (the container) or logit_targets as list[LogitTarget] directly, making the constructor cleaner.

Collaborator

@hannamw hannamw Feb 11, 2026


Thanks! One more request - can we make it so that only one of (attribution_targets or (logit_targets and logit_probabilities) is valid? It would be cleaner to have just one option. I guess I would prefer to get rid of the attribution_targets option and just access its attributes when creating the Graph in attribute, seeing as we don't actually use attribution_targets itself.

Contributor Author


Makes sense! The AttributionTargets container really is an attribution-phase concern, not something Graph needs to know about. I removed the attribution_targets parameter entirely and made logit_targets and logit_probabilities required positional parameters. attribute_transformerlens.py and attribute_nnsight.py now unpack the relevant attributes from the AttributionTargets instance.

@@ -0,0 +1,384 @@
"""Unit tests for AttributionTargets class."""
Collaborator


Would be nice to have a test that verifies the correctness of the List[tuple] formatted attribution targets. For example, attribute back from logit(x) - logit(y) (i.e. the difference in the corresponding unembed layers), and verify that, upon intervention, the difference in their logits changes as expected

Collaborator


One could also attribute back from None and its equivalent token lists / attribution target tuples and check for consistency

Contributor Author


Great suggestions — I thought it'd make sense to add these as integration tests, 2 per backend:

  1. test_custom_target_correctness — constructs a logit(x) - logit(y) direction from the unembed matrix, runs attribution with this as a CustomTarget, then ablates/amplifies the top features and verifies the logit difference changes in the expected direction. Both ablation and amplification produce pronounced effects in the expected directions.

  2. test_attribution_format_consistency — runs attribution with None (auto-select), then with equivalent Sequence[str] and Sequence[CustomTarget] targets, and verifies all three produce identical results: same targets, probabilities, and adjacency matrices.

All four tests use the real gemma-2-2b model. The TL and NNSight variants share backend-agnostic logic via helper functions.
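For readers following along, the logit-difference direction in test 1 can be sketched in plain Python (hypothetical helper; in the actual test the direction would be a torch.Tensor slice of the unembed matrix, and the middle tuple slot carries the target's scalar value):

```python
def logit_diff_target(unembed_cols, tok_x, tok_y, logits):
    """Build a (name, value, direction) custom-target triple for
    logit(x) - logit(y), given per-token unembed columns."""
    direction = [a - b for a, b in zip(unembed_cols[tok_x], unembed_cols[tok_y])]
    value = logits[tok_x] - logits[tok_y]
    return (f"logit({tok_x})-logit({tok_y})", value, direction)
```

Intervening on upstream features and re-measuring `logits[tok_x] - logits[tok_y]` then lets the test compare the observed change against what the attribution graph predicts.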

Collaborator


Thanks a bunch for adding these! For 1., could we instead do an exact check? If the direction is indeed logit(x) - logit(y), then we should be able to check that the change in logit(x) - logit(y) is precisely what our attribution graph predicts.

Contributor Author


Done! I updated test_custom_target_correctness to adopt the same attribution validation approach used elsewhere in the codebase (e.g. test_attributions_gemma.py::test_gemma_2_2b() verify_feature_edges).

Note that I needed to set logit_atol to be less tight than in verify_feature_edges (1e-4 vs 1e-5), but the other 3 tolerance thresholds are identical. I think this is probably defensible given the combined effects of cancellation (increased low-order-bit dependency) and asymmetric accumulation bias (baseline logits of 39 vs 31 create a downward bias for the 39 logit). Empirically, all 50 randomly-sampled features passed at 1e-4 while only ~50% passed at 1e-5. Let me know if you want to try upcasting to float64 or think we need to look at other strategies to get tighter tolerances here.

@speediedan
Contributor Author

speediedan commented Feb 11, 2026

Thanks a bunch for this big PR! I really appreciate the work that you've put into it. I added comments throughout; let me know what you think. I also think it'd be great to add a demo notebook showcasing this new feature. Thanks again for implementing this!

Thanks again for the thorough review @hannamw ! I've addressed all of your feedback in the latest commits and extracted the comment change hunks into a separate PR that I'll open in order to keep the focus of this PR on the attribution targets encapsulation.

Great suggestion on the demo notebook. Is it okay if I add that in a separate PR as soon as I get a chance in the next few days?

Collaborator

@hannamw hannamw left a comment


Thanks again for these changes! I've left just a few responses in some of the threads, but things are otherwise looking good. I'm looking forward to seeing the demo notebook and trying it out for myself!

…tion validation approach, remove attribution_targets ref from `Graph`, remove unnecessary vocab_indices properties, convert multi-token string detection handling from warning to error
- refined initial description and overview of the attribution targets API
- updated colab notebook setup to allow testing the attribution-targets branch before merging
- improved formatting of the attribution target, token mapping outputs and top feature comparison tables
- verified the demo notebook is working as expected in colab
…an_cuda` context manager to allow simple orchestration of sequential backend parity tests. Useful for VRAM constrained environments or to enable parity testing of larger models that may not fit on GPU simultaneously.
@hannamw
Collaborator

hannamw commented Feb 21, 2026

Just had a chance to look at this - the notebook looks good! The only change that I would suggest is to reorder it a bit, starting with the simpler types of attribution targets (None, then List[str] and torch.Tensor). I think this will be less confusing to readers, who might not want to immediately jump into diff targets or semantic concept targets.

…cted the `CustomTarget` examples and helper functions discussion to a distinct section. added a torch.Tensor version of the `Sequence[str]` example for completeness.
@speediedan
Contributor Author

Just had a chance to look at this - the notebook looks good! The only change that I would suggest is to reorder it a bit, starting with the simpler types of attribution targets (None, then List[str] and torch.Tensor). I think this will be less confusing to readers, who might not want to immediately jump into diff targets or semantic concept targets.

Good suggestion! I've restructured the demo to lead with the simpler target modes and extracted the CustomTarget examples and helper functions discussion to a distinct section to bolster clarity. I've improved a little of the exposition and added a torch.Tensor version of the Sequence[str] example as well for completeness.

I've also updated the Open In Colab badge and tutorial banner image to use the main branch, so we should be ready to go if things look good to you! 🚀 🎉

Collaborator

@hannamw hannamw left a comment


Thanks for all of these changes! I think this is ready to be merged.

@hannamw hannamw merged commit 43933d1 into decoderesearch:main Feb 23, 2026
1 check passed