Rewrite new simulator to use JSON mode; additional fixes to new simulator #34
Changes from 13 commits.
```diff
@@ -4,6 +4,7 @@
 import asyncio
 import logging
+import json
 from abc import ABC, abstractmethod
 from collections import OrderedDict
 from enum import Enum
```
```diff
@@ -36,6 +37,9 @@
 VALID_ACTIVATION_TOKENS_ORDERED = list(str(i) for i in range(MAX_NORMALIZED_ACTIVATION + 1))
 VALID_ACTIVATION_TOKENS = set(VALID_ACTIVATION_TOKENS_ORDERED)

+# Edge Case #3: The chat-based simulator is confused by end token. Replace it with a "not end token"
+END_OF_TEXT_TOKEN = "<|endoftext|>"
+END_OF_TEXT_TOKEN_REPLACEMENT = "<|not_endoftext|>"

 class SimulationType(str, Enum):
     """How to simulate neuron activations. Values correspond to subclasses of NeuronSimulator."""
```
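For context on the token vocabulary defined above: with `MAX_NORMALIZED_ACTIVATION = 10` (an assumed value, consistent with the 0-to-10 range the prompts in this file describe), the valid activation tokens are just the strings "0" through "10". A quick sketch:

```python
MAX_NORMALIZED_ACTIVATION = 10  # assumed; matches the 0-10 range used in the prompts

VALID_ACTIVATION_TOKENS_ORDERED = list(str(i) for i in range(MAX_NORMALIZED_ACTIVATION + 1))
VALID_ACTIVATION_TOKENS = set(VALID_ACTIVATION_TOKENS_ORDERED)

print(VALID_ACTIVATION_TOKENS_ORDERED[0], VALID_ACTIVATION_TOKENS_ORDERED[-1])  # 0 10
```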
```diff
@@ -590,51 +594,183 @@ def _format_record_for_logprob_free_simulation(
         activation_record.activations, max_activation=max_activation
     )
     for i, token in enumerate(activation_record.tokens):
+        # Edge Case #3: End tokens confuse the chat-based simulator. Replace end token with "not end token".
+        if token.strip() == END_OF_TEXT_TOKEN:
+            token = END_OF_TEXT_TOKEN_REPLACEMENT
         # We use a weird unicode character here to make it easier to parse the response (can split on "༗\n").
         if include_activations:
             response += f"{token}\t{normalized_activations[i]}༗\n"
         else:
             response += f"{token}\t༗\n"
     return response
+
+
+def _format_record_for_logprob_free_simulation_json(
+    neuron: int,
+    explanation: str,
+    activation_record: ActivationRecord,
+    include_activations: bool = False,
+    max_activation: Optional[float] = None,
+) -> str:
+    if include_activations:
+        assert max_activation is not None
+        assert len(activation_record.tokens) == len(
+            activation_record.activations
+        ), f"{len(activation_record.tokens)=}, {len(activation_record.activations)=}"
+        normalized_activations = normalize_activations(
+            activation_record.activations, max_activation=max_activation
+        )
+    return json.dumps({
+        "neuron": neuron,
+        "explanation": explanation,
+        "activations": [
+            {
+                "token": token,
+                "activation": normalized_activations[i] if include_activations else None
+            } for i, token in enumerate(activation_record.tokens)
+        ]
+    })
+
+
+def _parse_no_logprobs_completion_json(
+    completion: str,
+    tokens: Sequence[str],
+) -> Sequence[float]:
+    """
+    Parse a completion into a list of simulated activations. If the model did not faithfully
+    reproduce the token sequence, return a list of 0s. If the model's activation for a token
+    is not a number between 0 and 10 (inclusive), substitute 0.
+
+    Args:
+        completion: completion from the API
+        tokens: list of tokens as strings in the sequence where the neuron is being simulated
+    """
+    logger.debug("for tokens:\n%s", tokens)
+    logger.debug("received completion:\n%s", completion)
+
+    zero_prediction = [0] * len(tokens)
+
+    try:
+        completion = json.loads(completion)
+        if "activations" not in completion:
+            logger.error("The key 'activations' is not in the completion:\n%s\nExpected Tokens:\n%s", json.dumps(completion), tokens)
+            return zero_prediction
+        activations = completion["activations"]
+        if len(activations) != len(tokens):
+            logger.error("Tokens and activations length did not match:\n%s\nExpected Tokens:\n%s", json.dumps(completion), tokens)
+            return zero_prediction
+        predicted_activations = []
+        # check that there is a token and activation value
+        # no need to double check the token matches exactly
```
> **Collaborator:** Where's the first check?
>
> **Author:** I think I should have placed this comment one line lower. If you mean the token and activation value check: Or do you mean a different check?
>
> **Collaborator:** If I'm understanding the code correctly, it looks like you only check that the number of tokens is as expected, but you don't check that any of the tokens individually are correct. Is that right?
```diff
+        for i, activation in enumerate(activations):
+            if "token" not in activation:
+                logger.error("The key 'token' is not in activation:\n%s\nCompletion:%s\nExpected Tokens:\n%s", activation, json.dumps(completion), tokens)
+                predicted_activations.append(0)
+                continue
+            if "activation" not in activation:
+                logger.error("The key 'activation' is not in activation:\n%s\nCompletion:%s\nExpected Tokens:\n%s", activation, json.dumps(completion), tokens)
+                predicted_activations.append(0)
+                continue
+            # Ensure activation value is between 0-10 inclusive
+            try:
+                predicted_activation_float = float(activation["activation"])
+                if predicted_activation_float < 0 or predicted_activation_float > MAX_NORMALIZED_ACTIVATION:
+                    logger.error("activation value out of range: %s\nCompletion:%s\nExpected Tokens:\n%s", predicted_activation_float, json.dumps(completion), tokens)
+                    predicted_activations.append(0)
+                else:
+                    predicted_activations.append(predicted_activation_float)
+            except ValueError:
+                logger.error("activation value invalid: %s\nCompletion:%s\nExpected Tokens:\n%s", activation["activation"], json.dumps(completion), tokens)
+                predicted_activations.append(0)
+            except TypeError:
+                logger.error("activation value incorrect type: %s\nCompletion:%s\nExpected Tokens:\n%s", activation["activation"], json.dumps(completion), tokens)
+                predicted_activations.append(0)
+        logger.debug("predicted activations: %s", predicted_activations)
+        return predicted_activations
+
+    except json.JSONDecodeError:
+        logger.error("Failed to parse completion JSON:\n%s\nExpected Tokens:\n%s", completion, tokens)
+        return zero_prediction
+
+
 def _parse_no_logprobs_completion(
     completion: str,
     tokens: Sequence[str],
-) -> Sequence[int]:
+) -> Sequence[float]:
     """
     Parse a completion into a list of simulated activations. If the model did not faithfully
     reproduce the token sequence, return a list of 0s. If the model's activation for a token
-    is not an integer betwee 0 and 10, substitute 0.
+    is not a number between 0 and 10 (inclusive), substitute 0.

     Args:
         completion: completion from the API
         tokens: list of tokens as strings in the sequence where the neuron is being simulated
     """
     logger.debug("for tokens:\n%s", tokens)
     logger.debug("received completion:\n%s", completion)

     zero_prediction = [0] * len(tokens)
-    token_lines = completion.strip("\n").split("༗\n")
+    # FIX: Strip the last ༗\n, otherwise all last activations are invalid
```
> **Collaborator:** good catch! let's remove the …
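A quick illustration of the trailing-delimiter bug this change fixes (the `༗` delimiter and the strip/split logic come from the diff; the sample completion is invented):

```python
completion = "The\t0༗\n sky\t3༗\n"

# Without stripping the trailing delimiter, the final element keeps an
# extra "༗", so the last activation fails to parse.
print(completion.strip("\n").split("༗\n"))            # ['The\t0', ' sky\t3༗']

# Stripping the trailing "༗" first leaves every line cleanly parseable.
print(completion.strip("\n").strip("༗\n").split("༗\n"))  # ['The\t0', ' sky\t3']
```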
```diff
+    token_lines = completion.strip("\n").strip("༗\n").split("༗\n")
+    # Edge Case #2: Sometimes GPT doesn't use the special character when it answers, it only uses the "\n"
```
> **Collaborator:** Line breaks are fairly common. How often do we get cases where GPT doesn't use the special character and the text doesn't contain …
```diff
+    # The fix is to try splitting by \n if we detect that the response isn't the right format
+    # TODO: If there are also line breaks in the text, this will probably break
+    if (len(token_lines)) == 1:
+        token_lines = completion.strip("\n").strip("༗\n").split("\n")
     logger.debug("parsed completion into token_lines as:\n%s", token_lines)

     start_line_index = None
     for i, token_line in enumerate(token_lines):
-        if token_line.startswith(f"{tokens[0]}\t"):
+        if (token_line.startswith(f"{tokens[0]}\t")
+            # Edge Case #1: GPT often omits the space before the first token.
+            # Allow the returned token line to be either " token" or "token".
+            or f" {token_line}".startswith(f"{tokens[0]}\t")
+            # Edge Case #3: Allow our "not end token" replacement
+            or (token_line.startswith(END_OF_TEXT_TOKEN_REPLACEMENT) and tokens[0].strip() == END_OF_TEXT_TOKEN)
+        ):
+            logger.debug("start_line_index is: %s", start_line_index)
+            logger.debug("matched token %s with token_line %s", tokens[0], token_line)
             start_line_index = i
             break

     # If we didn't find the first token, or if the number of lines in the completion doesn't match
     # the number of tokens, return a list of 0s.
     if start_line_index is None or len(token_lines) - start_line_index != len(tokens):
+        logger.debug("didn't find first token or number of lines didn't match, returning all zeroes")
         return zero_prediction

     predicted_activations = []
     for i, token_line in enumerate(token_lines[start_line_index:]):
-        if not token_line.startswith(f"{tokens[i]}\t"):
+        if (not token_line.startswith(f"{tokens[i]}\t")
+            # Edge Case #1: GPT often omits the space before the token.
+            # Allow the returned token line to be either " token" or "token".
+            and not f" {token_line}".startswith(f"{tokens[i]}\t")
+            # Edge Case #3: Allow our "not end token" replacement
+            and not token_line.startswith(END_OF_TEXT_TOKEN_REPLACEMENT)
+        ):
+            logger.debug("failed to match token %s with token_line %s, returning all zeroes", tokens[i], token_line)
             return zero_prediction
-        predicted_activation = token_line.split("\t")[1]
-        if predicted_activation not in VALID_ACTIVATION_TOKENS:
+        predicted_activation_split = token_line.split("\t")
+        # Ensure token line has correct size after splitting. If not then assume it's a zero.
```
> **Collaborator:** Feels like there's a better way to do this, since I imagine tabs aren't rare. Maybe we could split on tabs and take the last element in the list? And then if there's a problem with the result, it will be caught by the activation parsing code below?
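The reviewer's alternative, splitting on tabs and taking the last element so that malformed values fall through to the numeric parsing, might be sketched like this (hypothetical, not the PR's code; the function name is invented):

```python
def parse_activation_value(token_line: str) -> float:
    # Take everything after the LAST tab, so tabs inside the token text
    # don't break parsing; bad values are caught by float() below.
    raw = token_line.rsplit("\t", 1)[-1]
    try:
        value = float(raw)
    except ValueError:
        return 0.0
    # Out-of-range values are zeroed, matching the PR's 0-10 clamp.
    return value if 0.0 <= value <= 10.0 else 0.0

print(parse_activation_value("a\ttab\ttoken\t7"))  # 7.0
```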
```diff
+        if len(predicted_activation_split) != 2:
+            logger.debug("tokenline split invalid size: %s", token_line)
             predicted_activations.append(0)
-        else:
-            predicted_activations.append(int(predicted_activation))
+            continue
+        predicted_activation = predicted_activation_split[1]
+        # Sometimes the activation value is not a float (GPT likes to append an extra ༗).
+        # In all cases if the activation is not numerically parseable, set it to 0
+        try:
+            predicted_activation_float = float(predicted_activation)
+            if predicted_activation_float < 0 or predicted_activation_float > MAX_NORMALIZED_ACTIVATION:
+                logger.debug("activation value out of range: %s", predicted_activation_float)
+                predicted_activations.append(0)
+            else:
+                predicted_activations.append(predicted_activation_float)
+        except ValueError:
+            logger.debug("activation value not numeric: %s", predicted_activation)
+            predicted_activations.append(0)
+    logger.debug("predicted activations: %s", predicted_activations)
     return predicted_activations


 class LogprobFreeExplanationTokenSimulator(NeuronSimulator):
     """
     Simulate neuron behavior based on an explanation.

@@ -695,6 +831,7 @@ def __init__(
         model_name: str,
         explanation: str,
         max_concurrent: Optional[int] = 10,
+        json_mode: Optional[bool] = True,
         few_shot_example_set: FewShotExampleSet = FewShotExampleSet.NEWER,
         prompt_format: PromptFormat = PromptFormat.HARMONY_V4,
         cache: bool = False,

@@ -705,6 +842,7 @@ def __init__(
         self.api_client = ApiClient(
             model_name=model_name, max_concurrent=max_concurrent, cache=cache
         )
+        self.json_mode = json_mode
         self.explanation = explanation
         self.few_shot_example_set = few_shot_example_set
         self.prompt_format = prompt_format

@@ -713,24 +851,30 @@ async def simulate(
         self,
         tokens: Sequence[str],
     ) -> SequenceSimulation:
-        prompt = self._make_simulation_prompt(
-            tokens,
-            self.explanation,
-        )
-        response = await self.api_client.make_request(
-            prompt=prompt, echo=False, max_tokens=1000
-        )
-        assert len(response["choices"]) == 1
-        choice = response["choices"][0]
-        if self.prompt_format == PromptFormat.HARMONY_V4:
+        if self.json_mode:
+            prompt = self._make_simulation_prompt_json(
+                tokens,
+                self.explanation,
+            )
+            response = await self.api_client.make_request(
+                messages=prompt, max_tokens=2000, temperature=0, json_mode=True
+            )
+            assert len(response["choices"]) == 1
+            choice = response["choices"][0]
             completion = choice["message"]["content"]
-        elif self.prompt_format in [PromptFormat.NONE, PromptFormat.INSTRUCTION_FOLLOWING]:
-            completion = choice["text"]
+            predicted_activations = _parse_no_logprobs_completion_json(completion, tokens)
         else:
-            raise ValueError(f"Unhandled prompt format {self.prompt_format}")
-
-        predicted_activations = _parse_no_logprobs_completion(completion, tokens)
+            prompt = self._make_simulation_prompt(
+                tokens,
+                self.explanation,
+            )
+            response = await self.api_client.make_request(
+                messages=prompt, max_tokens=1000, temperature=0
```
> **Collaborator:** let's make …
```diff
+            )
+            assert len(response["choices"]) == 1
+            choice = response["choices"][0]
+            completion = choice["message"]["content"]
+            predicted_activations = _parse_no_logprobs_completion(completion, tokens)

         result = SequenceSimulation(
             activation_scale=ActivationScale.SIMULATED_NORMALIZED_ACTIVATIONS,

@@ -743,6 +887,80 @@ async def simulate(
         logger.debug("result in score_explanation_by_activations is %s", result)
         return result
```
|
```diff
+    def _make_simulation_prompt_json(
+        self,
+        tokens: Sequence[str],
+        explanation: str,
+    ) -> Union[str, list[HarmonyMessage]]:
+        """Make a few-shot prompt for predicting the neuron's activations on a sequence.
+
+        NOTE: The JSON version does not give GPT multiple sequence examples per neuron.
+        """
+        assert explanation != ""
+        prompt_builder = PromptBuilder()
+        prompt_builder.add_message(
+            Role.SYSTEM,
+            """We're studying neurons in a neural network. Each neuron looks for some particular thing in a short document. Look at an explanation of what the neuron does, and try to predict its activations on a particular token.
+
+For each sequence, you will see the tokens in the sequence where the activations are left blank. You will print, in valid json, the exact same tokens verbatim, but with the activation values filled in according to the explanation.
+Fill out the activation values from 0 to 10. Most activations will be 0.
+""",
+        )
+
+        few_shot_examples = self.few_shot_example_set.get_examples()
+        for i, example in enumerate(few_shot_examples):
+            few_shot_example_max_activation = calculate_max_activation(example.activation_records)
+            # Example user message (activations left null for the model to fill in):
+            # {
+            #     "neuron": 1,
+            #     // The explanation for the neuron behavior
+            #     "explanation": "hello",
+            #     // Fill out the activation with a value from 0 to 10. Most activations will be 0.
+            #     "activations": [
+            #         {"token": "The", "activation": null}
+            #     ]
+            # }
+            prompt_builder.add_message(
+                Role.USER,
+                _format_record_for_logprob_free_simulation_json(i + 1, explanation=example.explanation, activation_record=example.activation_records[0], include_activations=False)
+            )
+            # Example assistant message (activations filled in):
+            # {
+            #     "neuron": 1,
+            #     "explanation": "hello",
+            #     "activations": [
+            #         {"token": "The", "activation": 3}
+            #     ]
+            # }
+            prompt_builder.add_message(
+                Role.ASSISTANT,
+                _format_record_for_logprob_free_simulation_json(i + 1, explanation=example.explanation, activation_record=example.activation_records[0], include_activations=True, max_activation=few_shot_example_max_activation)
+            )
+        neuron_index = len(few_shot_examples) + 1
+        # Final user message for the neuron actually being simulated:
+        # {
+        #     "neuron": 3,
+        #     "explanation": "hello",
+        #     "activations": [
+        #         {"token": "The", "activation": null}
+        #     ]
+        # }
+        prompt_builder.add_message(
+            Role.USER,
+            _format_record_for_logprob_free_simulation_json(neuron_index, explanation=explanation, activation_record=ActivationRecord(tokens=tokens, activations=[]), include_activations=False)
+        )
+        return prompt_builder.build(self.prompt_format, allow_extra_system_messages=True)

     def _make_simulation_prompt(
         self,
         tokens: Sequence[str],

@@ -753,7 +971,7 @@ def _make_simulation_prompt(
         prompt_builder = PromptBuilder(allow_extra_system_messages=True)
         prompt_builder.add_message(
             Role.SYSTEM,
-            """We're studying neurons in a neural network. Each neuron looks for some particular thing in a short document. Look at an explanation of what the neuron does, and try to predict its activations on a particular token.
+            """We're studying neurons in a neural network. Each neuron looks for some particular thing in a short document. Look at an explanation of what the neuron does, and try to predict its activations on a particular token.

 The activation format is token<tab>activation, and activations range from 0 to 10. Most activations will be 0.
 For each sequence, you will see the tokens in the sequence where the activations are left blank. You will print the exact same tokens verbatim, but with the activations filled in according to the explanation.
```
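To illustrate the message format the JSON-mode prompt sends, here is a standalone sketch mirroring the payload shape (the helper name, sample tokens, and explanation are invented; the real PR builds this via `_format_record_for_logprob_free_simulation_json`):

```python
import json

def format_record_json(neuron, explanation, tokens):
    # Simplified re-creation of the payload shape: activations are left
    # null for the model to fill in (the include_activations=False path).
    return json.dumps({
        "neuron": neuron,
        "explanation": explanation,
        "activations": [{"token": t, "activation": None} for t in tokens],
    })

msg = format_record_json(3, "words related to weather", ["The", " sky", " is"])
print(msg)
```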
> **Collaborator:** do you think it would help to add `indent=2` (or `4`) here? Maybe GPT-3.5 is better at parsing indented blocks of JSON rather than having everything on the same line.
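For reference, the `indent` parameter under discussion is standard `json.dumps` behavior; a minimal demo with an invented payload:

```python
import json

payload = {"neuron": 1, "activations": [{"token": "The", "activation": 0}]}

# indent=2 pretty-prints nested structures, one field per line,
# instead of emitting everything on a single line.
print(json.dumps(payload, indent=2))
```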
> **Author:** Could help and definitely worth testing. IMO I haven't seen GPT have trouble parsing unindented JSON, though I can't prove it.
>
> The main problem we ran into is GPT gets confused by non-ASCII characters (it stops generating abruptly, or it will add extra non-existent tokens). For example, the bullet point, ellipses, pound symbol, etc. I should open a separate issue for this, but it's not really fixable by this repo.
>
> But I think OpenAI is aware and trying to fix:
> https://community.openai.com/t/gpt-4-1106-preview-is-not-generating-utf-8/482839/6
> https://community.openai.com/t/gpt-4-1106-preview-messes-up-function-call-parameters-encoding/478500/36?page=2
>
> In the meantime our workaround is to double-escape non-ASCII chars BEFORE we feed it to automated-interpretability: hijohnnylin/neuronpedia-scorer@21c07d8
> E.g., "\u2022" becomes "\\u2022"
>
> I decided to put those "pre-processing" changes outside of this repo, since it's a temporary workaround until OpenAI fixes it, but lmk if you think it should be here instead. Can also make it an additional flag like `replace_non_ascii` or something.
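The double-escape workaround described above lives outside this repo; a rough sketch of what such a preprocessing step might look like (the function name `replace_non_ascii` is the hypothetical flag name floated in the thread, not an existing API):

```python
def replace_non_ascii(text: str) -> str:
    # Replace every non-ASCII character with its escaped form
    # (\xXX or \uXXXX), e.g. the bullet "\u2022" becomes the literal
    # six characters "\u2022", so the simulator only ever sees ASCII.
    return "".join(
        ch if ord(ch) < 128 else ch.encode("unicode_escape").decode("ascii")
        for ch in text
    )

print(replace_non_ascii("bullet \u2022 here"))
```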
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if it seems like GPT doesn't have trouble with unindented JSON then it's probably fine to leave it. in terms of the non-ascii characters, how prevalent is this problem? i.e. if you try to use the simulator (either in plaintext or json mode) how often does it bork the result?
Uh oh!
There was an error while loading. Please reload this page.
> **Author:** Apologies for the delay; have been working on other parts of NP.
>
> It borks it quite frequently on non-ASCII in JSON mode (not sure about plaintext). I started out by special-casing each character, but after 5-6 special cases it was apparent that it needed to all be excluded (that's also when I found the OpenAI community threads).
>
> Here's a code sample to reproduce the issue. Run with the latest `api_client` that supports JSON mode. GPT ends up returning truncated JSON that's unparseable. You should get the following output and error:
>
> Here's another reproducible example: replace `to_send`'s `activations` with the array below. We give GPT 63 activations and only receive 62 back; the ellipses symbol (…) token and activation are missing from the response. In this version of the bug, GPT just silently eliminates the token. You should see the following output, which shows the incorrect length of GPT's response.
> **Collaborator:** no worries about the delay and thanks for the thorough instructions on reproducing the problem! In that case, I think it would be great to put the non-ASCII preprocessing code in this file and add a flag to enable it like you suggested 🙏