From 73eceb7a58738a45e928e3962365f8d359626c1b Mon Sep 17 00:00:00 2001 From: williamrs-openai <78386475+williamrs-openai@users.noreply.github.com> Date: Thu, 28 Dec 2023 11:20:21 -0500 Subject: [PATCH 1/2] Update prompt_builder.py --- .../neuron_explainer/explanations/prompt_builder.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/neuron-explainer/neuron_explainer/explanations/prompt_builder.py b/neuron-explainer/neuron_explainer/explanations/prompt_builder.py index 3782940..cc6f655 100644 --- a/neuron-explainer/neuron_explainer/explanations/prompt_builder.py +++ b/neuron-explainer/neuron_explainer/explanations/prompt_builder.py @@ -49,8 +49,14 @@ class Role(str, Enum): class PromptBuilder: """Class for accumulating components of a prompt and then formatting them into an output.""" - def __init__(self) -> None: + + def __init__(self, allow_extra_system_messages: bool = False) -> None: + """The `allow_extra_system_messages` instance variable allows the caller to specify that the prompt + should be allowed to contain system messages after the very first one.""" + self._messages: list[HarmonyMessage] = [] + self._allow_extra_system_messages = allow_extra_system_messages + self.renderer = get_renderer("harmony_v4.0_no_system_message_content_type") def add_message(self, role: Role, message: str) -> None: self._messages.append(HarmonyMessage(role=role, content=message)) @@ -93,7 +99,7 @@ def build( for message in messages: role = message["role"] assert role == expected_next_role or ( - allow_extra_system_messages and role == Role.SYSTEM + (self._allow_extra_system_messages or allow_extra_system_messages) and role == Role.SYSTEM ), f"Expected message from {expected_next_role} but got message from {role}" if role == Role.SYSTEM: expected_next_role = Role.USER From 6cb468a736d4c171ab3fdaf90d2af1feb928b497 Mon Sep 17 00:00:00 2001 From: williamrs-openai <78386475+williamrs-openai@users.noreply.github.com> Date: Thu, 28 Dec 2023 11:21:25 -0500 Subject: [PATCH 2/2] Update prompt_builder.py --- neuron-explainer/neuron_explainer/explanations/prompt_builder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/neuron-explainer/neuron_explainer/explanations/prompt_builder.py b/neuron-explainer/neuron_explainer/explanations/prompt_builder.py index cc6f655..8776e7f 100644 --- a/neuron-explainer/neuron_explainer/explanations/prompt_builder.py +++ b/neuron-explainer/neuron_explainer/explanations/prompt_builder.py @@ -56,7 +56,6 @@ def __init__(self, allow_extra_system_messages: bool = False) -> None: self._messages: list[HarmonyMessage] = [] self._allow_extra_system_messages = allow_extra_system_messages - self.renderer = get_renderer("harmony_v4.0_no_system_message_content_type") def add_message(self, role: Role, message: str) -> None: self._messages.append(HarmonyMessage(role=role, content=message))