diff --git a/circuit_tracer/replacement_model/replacement_model_nnsight.py b/circuit_tracer/replacement_model/replacement_model_nnsight.py index 964c856..bbea150 100644 --- a/circuit_tracer/replacement_model/replacement_model_nnsight.py +++ b/circuit_tracer/replacement_model/replacement_model_nnsight.py @@ -214,6 +214,10 @@ def _configure_replacement_model( self.eval() self.cfg = convert_nnsight_config_to_transformerlens(self.config) + # special case to zero out user\n for gemmascope 2 (-it) transcoders + gemma_3_it = "gemma-3" in self.cfg.model_name and self.cfg.model_name.endswith("-it") + self.zero_positions = slice(0, 4) if gemma_3_it else slice(0, 1) + transcoder_set.to(self.device, self.dtype) self.transcoders = transcoder_set self.skip_transcoder = transcoder_set.skip_connection @@ -300,26 +304,6 @@ def fetch_activations( activation_layers: Iterator[int] | None = None, ): # special case to zero out user\n for gemmascope 2 (-it) transcoders - gemma_3_it = "gemma-3" in self.cfg.model_name and self.cfg.model_name.endswith("-it") - overlap = 0 - if gemma_3_it: - input_ids = self.input - ignore_prefix = torch.tensor( - [2, 105, 2364, 107], dtype=input_ids.dtype, device=input_ids.device - ) - min_len = min(len(input_ids), len(ignore_prefix)) - if min_len == 0: - overlap = 0 - else: - # Compare the overlapping portion - matches = input_ids[:min_len] == ignore_prefix[:min_len] - - # Find the first False (mismatch) - if matches.all(): - overlap = min_len - else: - overlap = matches.to(torch.int).argmin().item() - layers = range(self.cfg.n_layers) if activation_layers is None else activation_layers for layer in layers: feature_input_loc = self.get_feature_input_loc(layer) @@ -334,9 +318,7 @@ def fetch_activations( ) if not (append and len(activation_matrix[layer]) > 0): # type:ignore - transcoder_acts[0] = 0 - if gemma_3_it: - transcoder_acts[:overlap] = 0 + transcoder_acts[self.zero_positions] = 0 if sparse: transcoder_acts = transcoder_acts.to_sparse() @@ -454,9 +436,24 @@ def ensure_tokenized(self, prompt: str | torch.Tensor | list[int]) -> torch.Tens if tokens.ndim > 1: raise ValueError(f"Tensor must be 1-D, got shape {tokens.shape}") + tokens = tokens.to(self.device) + + gemma_3_it = "gemma-3" in self.cfg.model_name and self.cfg.model_name.endswith("-it") + if gemma_3_it: + ignore_prefix = torch.tensor( + [2, 105, 2364, 107], dtype=tokens.dtype, device=tokens.device + ) + tokenization_error = ( + "Input tokens should start with user\n, but got {tokens}" + ) + assert tokens.size(0) >= 4 and torch.all(tokens[:4] == ignore_prefix), ( + tokenization_error.format(tokens=self.tokenizer.decode(tokens.cpu().tolist())) + ) + return tokens + # Check if a special token is already present at the beginning if tokens[0] in self.tokenizer.all_special_ids: - return tokens.to(self.device) + return tokens # Prepend a special token to avoid artifacts at position 0 candidate_bos_token_ids = [ @@ -516,32 +513,14 @@ def setup_attribution(self, inputs: str | torch.Tensor): mlp_out_cache = save(torch.cat(mlp_out_cache, dim=0)) # type: ignore logits = save(self.output.logits) - # special case to zero out user\n for gemmascope 2 (-it) transcoders - gemma_3_it = "gemma-3" in self.cfg.model_name and self.cfg.model_name.endswith("-it") - zero_positions = slice(0, 1) - if gemma_3_it: - ignore_prefix = torch.tensor( - [2, 105, 2364, 107], dtype=tokens.dtype, device=tokens.device - ) - min_len = min(len(tokens), len(ignore_prefix)) - if min_len == 0: - zero_positions = slice(0, 0) - else: - # Compare the overlapping portion - matches = tokens[:min_len] == ignore_prefix[:min_len] - - # Find the first False (mismatch) - if matches.all(): - zero_positions = slice(0, min_len) - else: - zero_positions = slice(0, matches.to(torch.int).argmin().item()) - - attribution_data = transcoders.compute_attribution_components(mlp_in_cache, zero_positions) # type: ignore + attribution_data = transcoders.compute_attribution_components( + mlp_in_cache, self.zero_positions + ) # type: ignore # Compute error vectors error_vectors = mlp_out_cache - attribution_data["reconstruction"] - error_vectors[:, zero_positions] = 0 + error_vectors[:, self.zero_positions] = 0 token_vectors = self.embed_weight[ # type: ignore tokens ].detach() # (n_pos, d_model) # type: ignore diff --git a/circuit_tracer/replacement_model/replacement_model_transformerlens.py b/circuit_tracer/replacement_model/replacement_model_transformerlens.py index 0984a3e..4160e76 100644 --- a/circuit_tracer/replacement_model/replacement_model_transformerlens.py +++ b/circuit_tracer/replacement_model/replacement_model_transformerlens.py @@ -166,6 +166,10 @@ def _configure_replacement_model(self, transcoder_set: TranscoderSet | CrossLaye self.backend = "transformerlens" transcoder_set.to(self.cfg.device, self.cfg.dtype) + # special case to zero out user\n for gemmascope 2 (-it) transcoders + gemma_3_it = "gemma-3" in self.cfg.model_name and self.cfg.model_name.endswith("-it") + self.zero_positions = slice(0, 4) if gemma_3_it else slice(0, 1) + self.transcoders = transcoder_set self.feature_input_hook = transcoder_set.feature_input_hook self.original_feature_output_hook = transcoder_set.feature_output_hook @@ -288,7 +292,7 @@ def cache_activations(acts, hook, layer): ) if not append: - transcoder_acts[0] = 0 + transcoder_acts[self.zero_positions] = 0 if sparse: transcoder_acts = transcoder_acts.to_sparse() @@ -380,9 +384,24 @@ def ensure_tokenized(self, prompt: str | torch.Tensor | list[int]) -> torch.Tens if tokens.ndim > 1: raise ValueError(f"Tensor must be 1-D, got shape {tokens.shape}") + tokens = tokens.to(self.cfg.device) + + gemma_3_it = "gemma-3" in self.cfg.model_name and self.cfg.model_name.endswith("-it") + if gemma_3_it: + ignore_prefix = torch.tensor( + [2, 105, 2364, 107], dtype=tokens.dtype, device=tokens.device + ) + tokenization_error = ( + "Input tokens should start with user\n, but got {tokens}" + ) + assert tokens.size(0) >= 4 and torch.all(tokens[:4] == ignore_prefix), ( + tokenization_error.format(tokens=self.tokenizer.decode(tokens.cpu().tolist())) # type: ignore + ) + return tokens + # Check if a special token is already present at the beginning if tokens[0] in self.tokenizer.all_special_ids: # type: ignore - return tokens.to(self.cfg.device) + return tokens # Prepend a special token to avoid artifacts at position 0 candidate_bos_token_ids = [ @@ -432,12 +451,14 @@ def setup_attribution(self, inputs: str | torch.Tensor): mlp_in_cache = torch.cat(list(mlp_in_cache.values()), dim=0) mlp_out_cache = torch.cat(list(mlp_out_cache.values()), dim=0) - attribution_data = self.transcoders.compute_attribution_components(mlp_in_cache) + attribution_data = self.transcoders.compute_attribution_components( + mlp_in_cache, self.zero_positions + ) # Compute error vectors error_vectors = mlp_out_cache - attribution_data["reconstruction"] - error_vectors[:, 0] = 0 + error_vectors[:, self.zero_positions] = 0 token_vectors = self.W_E[tokens].detach() # (n_pos, d_model) return AttributionContext(