circuit_tracer/replacement_model/replacement_model_nnsight.py (25 additions & 46 deletions)
@@ -214,6 +214,10 @@ def _configure_replacement_model(
         self.eval()
         self.cfg = convert_nnsight_config_to_transformerlens(self.config)
 
+        # special case to zero out <bos><start_of_turn>user\n for gemmascope 2 (-it) transcoders
+        gemma_3_it = "gemma-3" in self.cfg.model_name and self.cfg.model_name.endswith("-it")
+        self.zero_positions = slice(0, 4) if gemma_3_it else slice(0, 1)
+
         transcoder_set.to(self.device, self.dtype)
         self.transcoders = transcoder_set
         self.skip_transcoder = transcoder_set.skip_connection
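
The new self.zero_positions attribute replaces the per-call prefix matching that the later hunks delete: feature activations and error vectors at these positions are masked, covering just the BOS token by default, or the full four-token <bos><start_of_turn>user\n chat prefix for Gemma-3 -it models. A minimal standalone sketch of the masking behavior (shapes assumed for illustration):

import torch

acts = torch.ones(6, 3)       # (n_pos, d_features); shape assumed for illustration
zero_positions = slice(0, 4)  # Gemma-3 -it case; slice(0, 1) for all other models
acts[zero_positions] = 0      # rows 0..3 are zeroed in place
assert acts[:4].abs().sum() == 0 and acts[4:].abs().sum() > 0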
@@ -300,26 +304,6 @@ def fetch_activations(
         activation_layers: Iterator[int] | None = None,
     ):
-        # special case to zero out <bos><start_of_turn>user\n for gemmascope 2 (-it) transcoders
-        gemma_3_it = "gemma-3" in self.cfg.model_name and self.cfg.model_name.endswith("-it")
-        overlap = 0
-        if gemma_3_it:
-            input_ids = self.input
-            ignore_prefix = torch.tensor(
-                [2, 105, 2364, 107], dtype=input_ids.dtype, device=input_ids.device
-            )
-            min_len = min(len(input_ids), len(ignore_prefix))
-            if min_len == 0:
-                overlap = 0
-            else:
-                # Compare the overlapping portion
-                matches = input_ids[:min_len] == ignore_prefix[:min_len]
-
-                # Find the first False (mismatch)
-                if matches.all():
-                    overlap = min_len
-                else:
-                    overlap = matches.to(torch.int).argmin().item()
 
         layers = range(self.cfg.n_layers) if activation_layers is None else activation_layers
         for layer in layers:
             feature_input_loc = self.get_feature_input_loc(layer)
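
For reference, the block deleted above computed a per-input overlap against the hardcoded Gemma-3 chat prefix instead of using a fixed slice. As a standalone helper it amounts to roughly the following (a sketch reconstructed from the removed lines; the helper name is hypothetical):

import torch

def prefix_overlap(input_ids: torch.Tensor, prefix: tuple[int, ...] = (2, 105, 2364, 107)) -> int:
    """Length of the longest shared prefix between input_ids and the
    Gemma-3 <bos><start_of_turn>user\n token ids (hypothetical helper)."""
    ignore_prefix = torch.tensor(prefix, dtype=input_ids.dtype, device=input_ids.device)
    min_len = min(len(input_ids), len(ignore_prefix))
    if min_len == 0:
        return 0
    matches = input_ids[:min_len] == ignore_prefix[:min_len]
    # argmin over the 0/1 tensor finds the first mismatch; all-True means full overlap
    return min_len if matches.all() else int(matches.to(torch.int).argmin().item())

The PR drops this computation in favor of asserting the prefix once in ensure_tokenized and masking the fixed self.zero_positions slice everywhere else.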
@@ -334,9 +318,7 @@
             )
 
             if not (append and len(activation_matrix[layer]) > 0):  # type:ignore
-                transcoder_acts[0] = 0
-                if gemma_3_it:
-                    transcoder_acts[:overlap] = 0
+                transcoder_acts[self.zero_positions] = 0
 
             if sparse:
                 transcoder_acts = transcoder_acts.to_sparse()
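
Zeroing before the sparse conversion means the masked positions simply drop out of the sparse representation. A quick illustration (values assumed):

import torch

acts = torch.tensor([[5.0, 0.0], [0.0, 3.0], [1.0, 0.0]])  # (n_pos, d); values assumed
acts[slice(0, 1)] = 0         # mask position 0 before sparsifying
sparse = acts.to_sparse()
assert sparse.values().numel() == 2  # only the nonzeros at positions 1 and 2 survive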
Expand Down Expand Up @@ -454,9 +436,24 @@ def ensure_tokenized(self, prompt: str | torch.Tensor | list[int]) -> torch.Tens
if tokens.ndim > 1:
raise ValueError(f"Tensor must be 1-D, got shape {tokens.shape}")

tokens = tokens.to(self.device)

gemma_3_it = "gemma-3" in self.cfg.model_name and self.cfg.model_name.endswith("-it")
if gemma_3_it:
ignore_prefix = torch.tensor(
[2, 105, 2364, 107], dtype=tokens.dtype, device=tokens.device
)
tokenization_error = (
"Input tokens should start with <bos><start_of_turn>user\n, but got {tokens}"
)
assert tokens.size(0) >= 4 and torch.all(tokens[:4] == ignore_prefix), (
tokenization_error.format(tokens=self.tokenizer.decode(tokens.cpu().tolist()))
)
return tokens

# Check if a special token is already present at the beginning
if tokens[0] in self.tokenizer.all_special_ids:
return tokens.to(self.device)
return tokens

# Prepend a special token to avoid artifacts at position 0
candidate_bos_token_ids = [
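
The hardcoded ids [2, 105, 2364, 107] are Gemma-3's encoding of <bos><start_of_turn>user\n. One way to sanity-check them, assuming the transformers library and access to a Gemma-3 -it checkpoint (the checkpoint name here is an assumption):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-3-4b-it")  # checkpoint choice assumed
ids = tok("<start_of_turn>user\n").input_ids  # Gemma tokenizers prepend <bos> by default
print(ids)  # should print [2, 105, 2364, 107] if the hardcoded prefix is right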
Expand Down Expand Up @@ -516,32 +513,14 @@ def setup_attribution(self, inputs: str | torch.Tensor):
mlp_out_cache = save(torch.cat(mlp_out_cache, dim=0)) # type: ignore
logits = save(self.output.logits)

# special case to zero out <bos><start_of_turn>user\n for gemmascope 2 (-it) transcoders
gemma_3_it = "gemma-3" in self.cfg.model_name and self.cfg.model_name.endswith("-it")
zero_positions = slice(0, 1)
if gemma_3_it:
ignore_prefix = torch.tensor(
[2, 105, 2364, 107], dtype=tokens.dtype, device=tokens.device
)
min_len = min(len(tokens), len(ignore_prefix))
if min_len == 0:
zero_positions = slice(0, 0)
else:
# Compare the overlapping portion
matches = tokens[:min_len] == ignore_prefix[:min_len]

# Find the first False (mismatch)
if matches.all():
zero_positions = slice(0, min_len)
else:
zero_positions = slice(0, matches.to(torch.int).argmin().item())

attribution_data = transcoders.compute_attribution_components(mlp_in_cache, zero_positions) # type: ignore
attribution_data = transcoders.compute_attribution_components(
mlp_in_cache, self.zero_positions
) # type: ignore

# Compute error vectors
error_vectors = mlp_out_cache - attribution_data["reconstruction"]

error_vectors[:, zero_positions] = 0
error_vectors[:, self.zero_positions] = 0
token_vectors = self.embed_weight[ # type: ignore
tokens
].detach() # (n_pos, d_model) # type: ignore
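
Error vectors capture what the transcoders fail to reconstruct; masking them at zero_positions keeps the deliberately zeroed prefix positions from surfacing as spurious error. A toy version of the computation (shapes assumed for illustration):

import torch

n_layers, n_pos, d_model = 2, 5, 8                # shapes assumed for illustration
mlp_out = torch.randn(n_layers, n_pos, d_model)   # stacked per-layer MLP outputs
recon = torch.randn(n_layers, n_pos, d_model)     # stacked transcoder reconstructions

error_vectors = mlp_out - recon
error_vectors[:, slice(0, 4)] = 0  # Gemma-3 -it mask, applied across every layer
assert error_vectors[:, :4].abs().sum() == 0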
[Second changed file: the TransformerLens backend counterpart of replacement_model_nnsight.py; file header not captured.]
@@ -166,6 +166,10 @@ def _configure_replacement_model(self, transcoder_set: TranscoderSet | CrossLayerTranscoder
         self.backend = "transformerlens"
         transcoder_set.to(self.cfg.device, self.cfg.dtype)
 
+        # special case to zero out <bos><start_of_turn>user\n for gemmascope 2 (-it) transcoders
+        gemma_3_it = "gemma-3" in self.cfg.model_name and self.cfg.model_name.endswith("-it")
+        self.zero_positions = slice(0, 4) if gemma_3_it else slice(0, 1)
+
         self.transcoders = transcoder_set
         self.feature_input_hook = transcoder_set.feature_input_hook
         self.original_feature_output_hook = transcoder_set.feature_output_hook
@@ -288,7 +292,7 @@ def cache_activations(acts, hook, layer):
             )
 
             if not append:
-                transcoder_acts[0] = 0
+                transcoder_acts[self.zero_positions] = 0
 
             if sparse:
                 transcoder_acts = transcoder_acts.to_sparse()
Expand Down Expand Up @@ -380,9 +384,24 @@ def ensure_tokenized(self, prompt: str | torch.Tensor | list[int]) -> torch.Tens
if tokens.ndim > 1:
raise ValueError(f"Tensor must be 1-D, got shape {tokens.shape}")

tokens = tokens.to(self.cfg.device)

gemma_3_it = "gemma-3" in self.cfg.model_name and self.cfg.model_name.endswith("-it")
if gemma_3_it:
ignore_prefix = torch.tensor(
[2, 105, 2364, 107], dtype=tokens.dtype, device=tokens.device
)
tokenization_error = (
"Input tokens should start with <bos><start_of_turn>user\n, but got {tokens}"
)
assert tokens.size(0) >= 4 and torch.all(tokens[:4] == ignore_prefix), (
tokenization_error.format(tokens=self.tokenizer.decode(tokens.cpu().tolist())) # type: ignore
)
return tokens

# Check if a special token is already present at the beginning
if tokens[0] in self.tokenizer.all_special_ids: # type: ignore
return tokens.to(self.cfg.device)
return tokens

# Prepend a special token to avoid artifacts at position 0
candidate_bos_token_ids = [
Expand Down Expand Up @@ -432,12 +451,14 @@ def setup_attribution(self, inputs: str | torch.Tensor):
mlp_in_cache = torch.cat(list(mlp_in_cache.values()), dim=0)
mlp_out_cache = torch.cat(list(mlp_out_cache.values()), dim=0)

attribution_data = self.transcoders.compute_attribution_components(mlp_in_cache)
attribution_data = self.transcoders.compute_attribution_components(
mlp_in_cache, self.zero_positions
)

# Compute error vectors
error_vectors = mlp_out_cache - attribution_data["reconstruction"]

error_vectors[:, 0] = 0
error_vectors[:, self.zero_positions] = 0
token_vectors = self.W_E[tokens].detach() # (n_pos, d_model)

return AttributionContext(
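
The old TransformerLens path zeroed only position 0 with an integer index; the slice form assigns the same elements when the slice covers a single position, and generalizes cleanly to the four-token prefix. A self-contained equivalence check:

import torch

a = torch.ones(2, 5, 3)
b = a.clone()
a[:, 0] = 0              # old form: integer index into the position dimension
b[:, slice(0, 1)] = 0    # new form: one-position slice, same elements assigned
assert torch.equal(a, b)  # identical result; slice(0, 4) extends the mask to 4 positions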