Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions circuit_tracer/replacement_model/_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from collections.abc import Sequence

import torch


def validate_single_sequence_inputs(
inputs: str | torch.Tensor | list[int],
method_name: str,
) -> None:
"""Raise a clear error when a single-sequence path receives batched inputs."""

if isinstance(inputs, torch.Tensor):
if inputs.ndim > 2 or (inputs.ndim == 2 and inputs.shape[0] != 1):
raise ValueError(
f"{method_name} only supports a single sequence, got tensor input with shape "
f"{tuple(inputs.shape)}. Loop over the batch instead."
)
return

if isinstance(inputs, Sequence) and not isinstance(inputs, str):
if inputs and not all(isinstance(token_id, int) for token_id in inputs):
raise ValueError(
f"{method_name} only supports a single sequence. Expected a single string, "
"a 1D token tensor, or a single list of token ids. Batched list inputs are "
"not supported."
)
5 changes: 5 additions & 0 deletions circuit_tracer/replacement_model/replacement_model_nnsight.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from nnsight import LanguageModel, Envoy, save, CONFIG as NNSIGHT_CONFIG

from circuit_tracer.attribution.context_nnsight import AttributionContext
from circuit_tracer.replacement_model._validation import validate_single_sequence_inputs
from circuit_tracer.transcoder import TranscoderSet
from circuit_tracer.transcoder.cross_layer_transcoder import CrossLayerTranscoder
from circuit_tracer.utils import get_default_device
Expand Down Expand Up @@ -486,6 +487,7 @@ def setup_attribution(self, inputs: str | torch.Tensor):
inputs (str): the inputs to attribute - hard coded to be a single string (no
batching) for now
"""
validate_single_sequence_inputs(inputs, "setup_attribution")

if isinstance(inputs, str):
tokens = self.ensure_tokenized(inputs)
Expand Down Expand Up @@ -555,6 +557,7 @@ def setup_intervention_with_freeze(
Returns:
tuple[torch.Tensor, list[Callable]]: The freeze hooks needed to run the desired intervention.
"""
validate_single_sequence_inputs(inputs, "setup_intervention_with_freeze")

def get_locs_to_freeze():
# this needs to go in a function that is called only in a trace context! otherwise you can't get the .source twice
Expand Down Expand Up @@ -771,6 +774,7 @@ def feature_intervention(
constrained_layers is not set), saving time. Activations are not returned.
Defaults to True.
"""
validate_single_sequence_inputs(inputs, "feature_intervention")
activation_matrix, activation_fn = self.get_activation_fn(
apply_activation_function=apply_activation_function, sparse=sparse
)
Expand Down Expand Up @@ -880,6 +884,7 @@ def feature_intervention_generate(
constrained_layers is not set), saving time. Returns None for activations.
Defaults to True.
"""
validate_single_sequence_inputs(inputs, "feature_intervention_generate")

# remove verbose kwarg, which is valid for TL models but not NNsight ones.
kwargs.pop("verbose", None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from transformer_lens.hook_points import HookPoint

from circuit_tracer.attribution.context_transformerlens import AttributionContext
from circuit_tracer.replacement_model._validation import validate_single_sequence_inputs
from circuit_tracer.transcoder import TranscoderSet
from circuit_tracer.transcoder.cross_layer_transcoder import CrossLayerTranscoder
from circuit_tracer.utils import get_default_device
Expand Down Expand Up @@ -430,6 +431,7 @@ def setup_attribution(self, inputs: str | torch.Tensor):
inputs (str): the inputs to attribute - hard coded to be a single string (no
batching) for now
"""
validate_single_sequence_inputs(inputs, "setup_attribution")

if isinstance(inputs, str):
tokens = self.ensure_tokenized(inputs)
Expand Down Expand Up @@ -488,6 +490,7 @@ def setup_intervention_with_freeze(
Returns:
list[tuple[str, Callable]]: The freeze hooks needed to run the desired intervention.
"""
validate_single_sequence_inputs(inputs, "setup_intervention_with_freeze")

hookpoints_to_freeze = ["hook_pattern"]
if constrained_layers:
Expand Down Expand Up @@ -773,6 +776,7 @@ def feature_intervention(
constrained_layers is not set), saving time. Activations are not returned.
Defaults to True.
"""
validate_single_sequence_inputs(inputs, "feature_intervention")

hooks, _, activation_cache = self._get_feature_intervention_hooks(
inputs,
Expand Down Expand Up @@ -851,6 +855,7 @@ def feature_intervention_generate(
constrained_layers is not set), saving time. Returns None for activations.
Defaults to True.
"""
validate_single_sequence_inputs(inputs, "feature_intervention_generate")

feature_intervention_hook_output = self._get_feature_intervention_hooks(
inputs,
Expand Down
138 changes: 138 additions & 0 deletions tests/test_replacement_model_input_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import gc

import numpy as np
import pytest
import torch
import torch.nn as nn
from transformer_lens import HookedTransformerConfig

from circuit_tracer import ReplacementModel
from circuit_tracer.transcoder import SingleLayerTranscoder, TranscoderSet
from circuit_tracer.transcoder.activation_functions import TopK


@pytest.fixture(autouse=True)
def cleanup_cuda():
    """Release GPU memory after every test so models don't accumulate.

    Runs `gc.collect()` BEFORE `torch.cuda.empty_cache()`: collecting first
    releases dead tensors back to the CUDA caching allocator, so the
    subsequent `empty_cache()` can actually return their blocks to the
    device. The reverse order (original code) misses blocks still held by
    uncollected garbage.
    """
    yield
    gc.collect()
    torch.cuda.empty_cache()


def load_dummy_llama_replacement_model() -> ReplacementModel:
    """Build a tiny Llama-style ReplacementModel for fast CPU-only tests.

    The config copies Llama-3.2-1B's architecture flags (rotary embeddings,
    RMS norm, gated MLP, grouped-query fields) but with toy dimensions
    (2 layers, d_model=32, vocab=128) and device "cpu" so construction is
    cheap. Each layer gets a randomly initialised single-layer transcoder.
    """
    cfg = HookedTransformerConfig.from_dict(
        {
            "n_layers": 2,
            "d_model": 32,
            "n_ctx": 32,
            "d_head": 8,
            "model_name": "Llama-3.2-1B",
            "n_heads": 4,
            "d_mlp": 64,
            "act_fn": "silu",
            "d_vocab": 128,
            "eps": 1e-05,
            "use_attn_result": False,
            "use_attn_scale": True,
            "attn_scale": np.float64(8.0),
            "use_split_qkv_input": False,
            "use_hook_mlp_in": False,
            "use_attn_in": False,
            "use_local_attn": False,
            "ungroup_grouped_query_attention": False,
            "original_architecture": "LlamaForCausalLM",
            "from_checkpoint": False,
            "checkpoint_index": None,
            "checkpoint_label_type": None,
            "checkpoint_value": None,
            "tokenizer_name": "gpt2",
            "window_size": None,
            "attn_types": None,
            "init_mode": "gpt2",
            "normalization_type": "RMSPre",
            "device": "cpu",
            "n_devices": 1,
            "attention_dir": "causal",
            "attn_only": False,
            "seed": 42,
            "initializer_range": np.float64(0.02),
            "init_weights": True,
            "scale_attn_by_inverse_layer_idx": False,
            "positional_embedding_type": "rotary",
            "final_rms": True,
            "d_vocab_out": 128,
            "parallel_attn_mlp": False,
            "rotary_dim": 8,
            "n_params": 123456,
            "use_hook_tokens": False,
            "gated_mlp": True,
            "default_prepend_bos": True,
            "dtype": torch.float32,
            "tokenizer_prepends_bos": True,
            "n_key_value_heads": 4,
            "post_embedding_ln": False,
            "rotary_base": 500000.0,
            "trust_remote_code": False,
            "rotary_adjacent_pairs": False,
            "load_in_4bit": False,
            "num_experts": None,
            "experts_per_token": None,
            "relative_attention_max_distance": None,
            "relative_attention_num_buckets": None,
            "decoder_start_token_id": None,
            "tie_word_embeddings": False,
            "use_normalization_before_and_after": False,
            "attn_scores_soft_cap": -1.0,
            "output_logits_soft_cap": -1.0,
            "use_NTK_by_parts_rope": True,
            "NTK_by_parts_low_freq_factor": 1.0,
            "NTK_by_parts_high_freq_factor": 4.0,
            "NTK_by_parts_factor": 32.0,
        }
    )

    # One toy transcoder per layer: d_model -> 2*d_model features with a
    # TopK(8) activation and a skip connection.
    transcoders = {
        layer_idx: SingleLayerTranscoder(
            cfg.d_model, cfg.d_model * 2, TopK(8), layer_idx, skip_connection=True
        )
        for layer_idx in range(cfg.n_layers)
    }
    # Small random weights so feature activations are non-trivial in tests.
    for transcoder in transcoders.values():
        for _, param in transcoder.named_parameters():
            nn.init.uniform_(param, a=-0.1, b=0.1)

    return ReplacementModel.from_config(
        cfg,
        TranscoderSet(
            transcoders,
            feature_input_hook="mlp.hook_in",
            feature_output_hook="mlp.hook_out",
        ),
    )


@pytest.mark.parametrize(
    "inputs",
    [
        torch.tensor([[1, 2, 3, 4], [1, 5, 6, 7]], dtype=torch.long),
        ["short prompt", "a longer prompt"],
    ],
)
def test_feature_intervention_rejects_batched_inputs(inputs):
    """Batched tensors and lists of prompts are rejected before any forward pass."""
    model = load_dummy_llama_replacement_model()
    intervention = [(0, 1, 0, 0.0)]

    with pytest.raises(ValueError, match="only supports a single sequence"):
        model.feature_intervention(inputs, intervention)  # type: ignore[arg-type]


@pytest.mark.parametrize(
    "inputs",
    [
        torch.tensor([[1, 2, 3, 4], [1, 5, 6, 7]], dtype=torch.long),
        ["short prompt", "a longer prompt"],
    ],
)
def test_feature_intervention_generate_rejects_batched_inputs(inputs):
    """The generate path also refuses batched inputs with a clear error."""
    model = load_dummy_llama_replacement_model()
    intervention = [(0, slice(1, None), 0, 0.0)]

    with pytest.raises(ValueError, match="only supports a single sequence"):
        model.feature_intervention_generate(inputs, intervention)  # type: ignore[arg-type]