
Add transformers v5.0.0 compatibility#79

Open
speediedan wants to merge 29 commits into decoderesearch:main from speediedan:v5-transformers-support

Conversation

@speediedan
Contributor

This PR adds support for transformers>=5.0.0 and huggingface_hub>=1.0.0 while maintaining backward compatibility with transformers v4.x. The changes address breaking API changes in both libraries that affected model structure paths, MoE internals, and exception constructors.

Key adaptations:

  • Updated model path structure for Gemma3ForConditionalGeneration
  • Updated MLP hook paths for v5's new router_scores format in GptOss MoE
  • Updated test fixtures for v5 API changes
  • Runtime detection of transformers version for conditional mappings to enable v4 BC

Key Changes

1. Version Detection (tl_nnsight_mapping.py)

New module-level version detection:

```python
from packaging import version
import transformers

TRANSFORMERS_VERSION = version.parse(transformers.__version__)
TRANSFORMERS_GTE_5_0_0 = TRANSFORMERS_VERSION >= version.parse("5.0.0")
```

This constant enables conditional logic throughout the codebase without repeated version checks.
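As a standalone illustration of this pattern (the `is_gte_5` helper is hypothetical and transformers is not imported here), `version.parse` gives correct ordering where naive string comparison would fail, including for pre-releases:

```python
# Sketch of the version-gating check, isolated from transformers itself.
# `is_gte_5` is a hypothetical helper mirroring the module-level constant.
from packaging import version

def is_gte_5(version_str: str) -> bool:
    return version.parse(version_str) >= version.parse("5.0.0")

# Naive string comparison would sort "10.0.0" before "5.0.0"; version.parse does not.
print(is_gte_5("4.57.3"))      # False
print(is_gte_5("5.1.0"))       # True
print(is_gte_5("10.0.0"))      # True
print(is_gte_5("5.0.0.dev0"))  # False: PEP 440 pre-releases sort before the final release
```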

2. Gemma3ForConditionalGeneration Path Updates

Breaking Change in transformers v5:
The multimodal Gemma3 model structure changed from:

```
language_model.layers[{layer}]  # v4
```

to:

```
model.language_model.layers[{layer}]  # v5
```

Solution:

```python
_gemma3_cond_prefix = "model.language_model" if TRANSFORMERS_GTE_5_0_0 else "language_model"
```

All paths in gemma_3_conditional_mapping now use this dynamic prefix, including:

  • attention_location_pattern
  • layernorm_scale_location_patterns (7 patterns)
  • pre_logit_location
  • embed_location
  • embed_weight
  • feature_hook_mapping (4 hooks)

3. GptOss MoE MLP Hook Updates

Breaking Change in transformers v5:
The MoE layer's internal processing changed, affecting which NNSight .source reference captures the 3D hidden states output:

| Aspect | v4 | v5 |
| --- | --- | --- |
| `router_scores` shape | `(num_tokens, num_experts)` sparse | `(num_tokens, top_k)` compact |
| 3D `hidden_states` source | `self_experts_0` | `hidden_states_reshape_1` |
| MLP forward | Experts return 3D directly | Explicit reshape from 2D→3D before return |

Both versions return (hidden_states, router_scores) tuple, but v5 adds an explicit reshape operation:

v5: https://github.com/huggingface/transformers/blob/08810b1/src/transformers/models/gpt_oss/modular_gpt_oss.py#L136-L142

```python
hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return hidden_states, router_scores
```

v4: https://github.com/huggingface/transformers/blob/47b0e47/src/transformers/models/gpt_oss/modular_gpt_oss.py#L167-L170

```python
routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
return routed_out, router_scores
```

Solution:

```python
_gpt_oss_mlp_hook = (
    "model.layers[{layer}].mlp.source.hidden_states_reshape_1"
    if TRANSFORMERS_GTE_5_0_0
    else "model.layers[{layer}].mlp.source.self_experts_0"
)
```

Both mlp.hook_out and hook_mlp_out in gpt_oss_mapping now use this dynamic hook path.

4. Test Fixture Updates

4.1 HuggingFace Utils Tests (test_hf_utils.py)

GatedRepoError and RepositoryNotFoundError constructors now require a response parameter in huggingface_hub >=1.3.4.

Before:

```python
mock_download.side_effect = GatedRepoError("User has not accepted terms.")
mock_repo_info.side_effect = RepositoryNotFoundError("Repo not found.")
```

After:

```python
mock_response = mock.MagicMock()
mock_response.status_code = 403
mock_download.side_effect = GatedRepoError("User has not accepted terms.", response=mock_response)
mock_repo_info.side_effect = RepositoryNotFoundError("Repo not found.", response=mock_response)
```

4.2 GptOss Attribution Tests (test_attributions_gpt_oss_nnsight.py)

transformers v5 MoE operations (torch._grouped_mm) don't support float64 dtype.

https://github.com/huggingface/transformers/blob/08810b1e278938278c50153ee1edfd7a20a759da/src/transformers/integrations/moe.py#L188

Changes:

  • Changed test dtype from float64 to float32 in:
    • load_large_gpt_oss_model_with_dummy_clt()
    • test_large_gpt_oss_model()
  • Added relaxed tolerances for float32 precision in test_large_gpt_oss_model:
    verify_token_and_error_edges(model, graph, act_atol=0.2, act_rtol=1e-2, logit_atol=0.2, logit_rtol=1e-2)
    verify_feature_edges(model, graph, act_atol=0.2, act_rtol=1e-2, logit_atol=0.2, logit_rtol=1e-2)
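To make the relaxed tolerances concrete, here is a sketch of the standard allclose-style criterion these arguments feed into, |actual − expected| ≤ atol + rtol·|expected|; the `within_tolerance` helper is illustrative and not part of the test suite:

```python
# Sketch of the allclose-style criterion behind act_atol/act_rtol:
# a comparison passes when |actual - expected| <= atol + rtol * |expected|.
def within_tolerance(actual: float, expected: float, atol: float, rtol: float) -> bool:
    return abs(actual - expected) <= atol + rtol * abs(expected)

# A float32-scale discrepancy that fails strict defaults but passes the
# relaxed tolerances used in test_large_gpt_oss_model.
print(within_tolerance(10.15, 10.0, atol=1e-8, rtol=1e-5))  # False
print(within_tolerance(10.15, 10.0, atol=0.2, rtol=1e-2))   # True
```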

5. HuggingFace Transfer Mode Environment Variable (hf_utils.py)

The HF_HUB_ENABLE_HF_TRANSFER environment variable was removed in huggingface_hub v1.0+ and replaced with HF_XET_HIGH_PERFORMANCE.

These environment variables control high-performance file transfer modes that require sequential (non-parallel) downloads. The download_hf_uris() function previously checked HF_HUB_ENABLE_HF_TRANSFER to decide between sequential and parallel downloads.

Solution:

```python
from circuit_tracer.utils.tl_nnsight_mapping import TRANSFORMERS_GTE_5_0_0

# In download_hf_uris():
if TRANSFORMERS_GTE_5_0_0:
    use_sequential = os.environ.get("HF_XET_HIGH_PERFORMANCE", "0") == "1"
else:
    use_sequential = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") == "1"

if use_sequential:
    results = [_download(uri) for uri in uri_list]  # Sequential
else:
    results = thread_map(_download, uri_list, ...)   # Parallel
```

This ensures:

  • v4: Respects HF_HUB_ENABLE_HF_TRANSFER for hf_transfer compatibility
  • v5: Respects HF_XET_HIGH_PERFORMANCE for hf_xet compatibility
  • Default behavior (no env var set) uses parallel downloads for both versions

6. Dependency Updates (pyproject.toml)

Updated constraints:

```toml
[project]
dependencies = [
    ...
    "transformers>=4.56.0,<=5.1.0",  # was <=4.57.3
]

[project.optional-dependencies]
dev = ["pytest>=8.0.0", "pytest-rerunfailures>=14.0", ...]  # added pytest-rerunfailures
```

Migration Path

For users on transformers v4.x:

No changes required. The version detection ensures v4 paths are used automatically.

For users upgrading to v5:

  1. Update transformers: pip install "transformers>=5.0.0" (quoted so the shell does not treat >= as a redirect)
  2. Update huggingface_hub: pip install "huggingface_hub>=1.0.0"
  3. If using GptOss MoE models with float64, switch to float32

Backward Compatibility

The changes are designed to be fully backward compatible:

  • Version detection selects appropriate paths at import time
  • No API changes for users
  • Existing code continues to work with both v4 and v5

Technical Notes

NNSight .source Attribute Behavior

Both v4 and v5 return (hidden_states, router_scores) tuple from GptOssMLP.forward(). The difference is in how the 3D hidden states are obtained:

  • v4: The experts module returns 3D output directly, which NNSight captures via self_experts_0
  • v5: The MLP explicitly reshapes from 2D→3D before returning, which NNSight captures via hidden_states_reshape_1

Additionally, the router_scores shape changed from sparse (num_tokens, num_experts) in v4 to compact (num_tokens, top_k) in v5.

This change affects which NNSight .source reference captures the 3D hidden states output when tracing the computational graph.
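The router_scores shape change can be illustrated with a tiny pure-Python stand-in (no torch; the numbers are arbitrary) showing how one token's dense v4-style score row reduces to the compact top_k layout:

```python
# Illustrative only: dense per-expert scores vs. a compact top_k layout.
# One token, num_experts=4, top_k=2; the score values are made up.
top_k = 2
v4_row = [0.0, 0.7, 0.0, 0.3]  # v4-style: one score slot per expert, zeros elsewhere

# v5-style compact layout keeps only the top_k (expert-index, score) pairs
pairs = sorted(enumerate(v4_row), key=lambda p: p[1], reverse=True)[:top_k]
expert_indices = [i for i, _ in pairs]
compact_scores = [s for _, s in pairs]

print(expert_indices, compact_scores)  # [1, 3] [0.7, 0.3]
```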

Gemma3 Conditional Model Structure

The multimodal Gemma3 model wraps the language model differently in v5:

  • v4: Direct access via language_model.*
  • v5: Nested access via model.language_model.*

This is consistent with how other multimodal models (like LLaVA) structure their components.


hannamw and others added 14 commits November 5, 2025 16:02
Squashed changes from arbitrary-logit-attribution branch.
…he functionality proposed in PR decoderesearch#44 and provides a unified foundation for future attribution target enhancements.
…ing cases (decoderesearch#51)

* fix: handle single module in offload_modules

Fix TypeError when passing a single module (e.g., CrossLayerTranscoder)
to offload_modules instead of a list. Now properly handles single
modules, lists, and PyTorch container types (ModuleList, ModuleDict,
Sequential).

* minor aesthetic change, slight simplification of logic
…ch sizes to allow running more tests with minimum CI hardware profile
- Fix Gemma3 conditional model prefix (language_model → model.language_model in v5)
- Fix GptOss MoE MLP hook reference for v5 reshape changes
- Add TRANSFORMERS_GTE_5_0_0 version detection constant
- Update hf_utils.py to use HF_XET_HIGH_PERFORMANCE (v5) or HF_HUB_ENABLE_HF_TRANSFER (v4) for sequential downloads
- Cast tokenizer.decode() to str to fix pyright type errors
- Add assertions for optional config fields in tests
- Fix huggingface_hub >=1.3.4 mock construction (require response parameter)
- Change GptOss test dtype to float32 (v5 MoE doesn't support float64)
- Update pyproject.toml to allow transformers <=5.1.0
@speediedan speediedan marked this pull request as ready for review January 31, 2026 20:32
Copilot AI review requested due to automatic review settings January 31, 2026 20:32

Copilot AI left a comment


Pull request overview

This PR adds support for transformers>=5.0.0 and huggingface_hub>=1.0.0 while maintaining backward compatibility with transformers v4.x. The changes address breaking API changes in model structure paths, MoE internals, exception constructors, and environment variables.

Changes:

  • Added version detection using the packaging library to conditionally apply v4 vs v5 logic
  • Updated model path structures for Gemma3ForConditionalGeneration and GptOss MoE models
  • Introduced new AttributionTargets and LogitTarget classes to support flexible target specifications including arbitrary string tokens via virtual indices
  • Updated test fixtures for huggingface_hub v1.0+ API changes and transformers v5 MoE float32 requirements
  • Modified environment variable handling for high-performance transfer modes (HF_HUB_ENABLE_HF_TRANSFER → HF_XET_HIGH_PERFORMANCE)

Reviewed changes

Copilot reviewed 29 out of 29 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| circuit_tracer/utils/tl_nnsight_mapping.py | Added version detection and conditional model path mappings for v4/v5 compatibility |
| circuit_tracer/utils/hf_utils.py | Updated environment variable handling for high-performance transfer modes |
| circuit_tracer/attribution/targets.py | New module with AttributionTargets class for flexible target specification |
| circuit_tracer/graph.py | Updated Graph class to use LogitTarget with backward compatibility for tensor format |
| circuit_tracer/attribution/attribute*.py | Updated attribution functions to use new AttributionTargets API |
| circuit_tracer/utils/create_graph_files.py | Updated to handle LogitTarget unpacking for graph file generation |
| tests/utils/test_hf_utils.py | Updated mock error constructors to include required response parameter |
| tests/test_attributions_gpt_oss_nnsight.py | Changed dtype from float64 to float32 with relaxed tolerances for v5 MoE compatibility |
| tests/test_attributions_gemma*.py | Added batch_size parameter to reduce memory usage and added test markers |
| pyproject.toml | Updated version constraints for transformers and huggingface_hub, added pytest-rerunfailures |


@speediedan
Contributor Author

@hannamw Validating transformers v5.0.0 support was important for the world-model analysis framework I'm building, so I figured I should do my part by helping out with circuit-tracer support for it.

I've based this v5 transformers compatibility PR on the AttributionTargets encapsulation proposal I submitted, to ensure the two are compatible, but I can submit a separate PR for that if you prefer.

… review feedback, includes comment and mark changes that will be separated into a separate PR
…tion validation approach, remove attribution_targets ref from `Graph`, remove unnecessary vocab_indices properties, convert multi-token string detection handling from warning to error
- refined initial description and overview of the attribution targets API
- updated colab notebook setup to allow testing the attribution-targets branch before merging
- improved formatting of the attribution target, token mapping outputs and top feature comparison tables
- verified the demo notebook is working as expected in colab
…an_cuda` context manager to allow simple orchestration of sequential backend parity tests. Useful for VRAM constrained environments or to enable parity testing of larger models that may not fit on GPU simultaneously.
…cted the `CustomTarget` examples and helper functions discussion to a distinct section. added a torch.Tensor version of the `Sequence[str]` example for completeness.