Migrate from template to prompt arg while keeping backward compatibility (#1066)
sidmohanty11 authored Dec 28, 2023
1 parent 12e6eaf commit d9d5299
Showing 9 changed files with 56 additions and 42 deletions.
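In short, the config key and keyword argument `template` is renamed to `prompt`; the old name is still accepted but now logs a deprecation warning and is copied into `prompt`. A minimal sketch of the resulting behaviour, assuming `BaseLlmConfig` is importable from `embedchain.config` (the import path is not part of this diff):

```python
from string import Template

from embedchain.config import BaseLlmConfig  # import path assumed

# New, preferred argument name.
config = BaseLlmConfig(prompt="Use $context to answer.\nQuery: $query\nAnswer:")

# The old `template` argument still works, but should log a deprecation
# warning and be copied into `prompt` internally.
legacy = BaseLlmConfig(template=Template("Use $context to answer.\nQuery: $query\nAnswer:"))
assert legacy.prompt.template == config.prompt.template
```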
2 changes: 1 addition & 1 deletion configs/full-stack.yaml
@@ -15,7 +15,7 @@ llm:
max_tokens: 1000
top_p: 1
stream: false
template: |
prompt: |
Use the following pieces of context to answer the query at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
8 changes: 4 additions & 4 deletions docs/api-reference/advanced/configuration.mdx
@@ -26,7 +26,7 @@ llm:
top_p: 1
stream: false
api_key: sk-xxx
template: |
prompt: |
Use the following pieces of context to answer the query at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
@@ -73,7 +73,7 @@ chunker:
"max_tokens": 1000,
"top_p": 1,
"stream": false,
"template": "Use the following pieces of context to answer the query at the end.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n$context\n\nQuery: $query\n\nHelpful Answer:",
"prompt": "Use the following pieces of context to answer the query at the end.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n$context\n\nQuery: $query\n\nHelpful Answer:",
"system_prompt": "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare.",
"api_key": "sk-xxx"
}
@@ -117,7 +117,7 @@ config = {
'max_tokens': 1000,
'top_p': 1,
'stream': False,
'template': (
'prompt': (
"Use the following pieces of context to answer the query at the end.\n"
"If you don't know the answer, just say that you don't know, don't try to make up an answer.\n"
"$context\n\nQuery: $query\n\nHelpful Answer:"
@@ -170,7 +170,7 @@ Alright, let's dive into what each key means in the yaml config above:
- `max_tokens` (Integer): Controls how many tokens are used in the response.
- `top_p` (Float): Controls the diversity of word selection. A higher value (closer to 1) makes word selection more diverse.
- `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
- `template` (String): A custom template for the prompt that the model uses to generate responses.
- `prompt` (String): A prompt for the model to follow when generating responses; it must contain the `$context` and `$query` variables (see the sketch after this list).
- `system_prompt` (String): A system prompt for the model to follow when generating responses, in this case, it's set to the style of William Shakespeare.
- `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
- `number_documents` (Integer): Number of documents to pull from the vectordb as context, defaults to 1
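As the `prompt` entry above notes, the string must contain both placeholders; if either is missing, the config layer rejects it. A hedged sketch of that behaviour (same assumed import path as above):

```python
from embedchain.config import BaseLlmConfig  # import path assumed

# Missing $context: validation fails with the ValueError raised in base.py.
try:
    BaseLlmConfig(prompt="Answer the query: $query")
except ValueError as err:
    print(err)

# Both placeholders present: accepted and stored as a string.Template.
config = BaseLlmConfig(prompt="Use $context to answer the query: $query")
print(config.prompt.template)
```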
2 changes: 1 addition & 1 deletion docs/examples/rest-api/create.mdx
@@ -37,7 +37,7 @@ llm:
max_tokens: 1000
top_p: 1
stream: false
template: |
prompt: |
Use the following pieces of context to answer the query at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
49 changes: 31 additions & 18 deletions embedchain/config/llm/base.py
@@ -1,3 +1,4 @@
import logging
import re
from string import Template
from typing import Any, Dict, List, Optional
@@ -59,6 +60,7 @@ def __init__(
self,
number_documents: int = 3,
template: Optional[Template] = None,
prompt: Optional[Template] = None,
model: Optional[str] = None,
temperature: float = 0,
max_tokens: int = 1000,
@@ -80,8 +82,11 @@ def __init__(
context, defaults to 1
:type number_documents: int, optional
:param template: The `Template` instance to use as a template for
prompt, defaults to None
prompt, defaults to None (deprecated)
:type template: Optional[Template], optional
:param prompt: The `Template` instance to use as a template for
prompt, defaults to None
:type prompt: Optional[Template], optional
:param model: Controls the OpenAI model used, defaults to None
:type model: Optional[str], optional
:param temperature: Controls the randomness of the model's output.
@@ -106,8 +111,16 @@ def __init__(
contain $context and $query (and optionally $history)
:raises ValueError: Stream is not boolean
"""
if template is None:
template = DEFAULT_PROMPT_TEMPLATE
if template is not None:
logging.warning(
"The `template` argument is deprecated and will be removed in a future version. "
+ "Please use `prompt` instead."
)
if prompt is None:
prompt = template

if prompt is None:
prompt = DEFAULT_PROMPT_TEMPLATE

self.number_documents = number_documents
self.temperature = temperature
@@ -120,37 +133,37 @@ def __init__(
self.callbacks = callbacks
self.api_key = api_key

if type(template) is str:
template = Template(template)
if type(prompt) is str:
prompt = Template(prompt)

if self.validate_template(template):
self.template = template
if self.validate_prompt(prompt):
self.prompt = prompt
else:
raise ValueError("`template` should have `query` and `context` keys and potentially `history` (if used).")
raise ValueError("The 'prompt' should have 'query' and 'context' keys and potentially 'history' (if used).")

if not isinstance(stream, bool):
raise ValueError("`stream` should be bool")
self.stream = stream
self.where = where

def validate_template(self, template: Template) -> bool:
def validate_prompt(self, prompt: Template) -> bool:
"""
validate the template
validate the prompt
:param template: the template to validate
:type template: Template
:param prompt: the prompt to validate
:type prompt: Template
:return: valid (true) or invalid (false)
:rtype: bool
"""
return re.search(query_re, template.template) and re.search(context_re, template.template)
return re.search(query_re, prompt.template) and re.search(context_re, prompt.template)

def _validate_template_history(self, template: Template) -> bool:
def _validate_prompt_history(self, prompt: Template) -> bool:
"""
validate the template with history
validate the prompt with history
:param template: the template to validate
:type template: Template
:param prompt: the prompt to validate
:type prompt: Template
:return: valid (true) or invalid (false)
:rtype: bool
"""
return re.search(history_re, template.template)
return re.search(history_re, prompt.template)
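`validate_prompt` (previously `validate_template`) only checks that the `$context` and `$query` placeholders appear in the prompt string; `_validate_prompt_history` does the same for `$history`. A standalone sketch of the check, with hypothetical regexes standing in for the module-level `query_re`/`context_re`, whose exact definitions are outside this hunk:

```python
import re
from string import Template

# Hypothetical stand-ins for the module-level query_re / context_re patterns;
# the real definitions live elsewhere in embedchain/config/llm/base.py.
query_re = re.compile(r"\$\{?query\}?")
context_re = re.compile(r"\$\{?context\}?")

def validate_prompt(prompt: Template) -> bool:
    # Valid only if both placeholders appear in the raw template string.
    return bool(query_re.search(prompt.template) and context_re.search(prompt.template))

assert validate_prompt(Template("Use $context to answer.\nQuery: $query"))
assert not validate_prompt(Template("Query: $query"))  # no $context
```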
28 changes: 14 additions & 14 deletions embedchain/llm/base.py
@@ -74,19 +74,19 @@ def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[
if web_search_result:
context_string = self._append_search_and_context(context_string, web_search_result)

template_contains_history = self.config._validate_template_history(self.config.template)
if template_contains_history:
# Template contains history
prompt_contains_history = self.config._validate_prompt_history(self.config.prompt)
if prompt_contains_history:
# Prompt contains history
# If there is no history yet, we insert `- no history -`
prompt = self.config.template.substitute(
prompt = self.config.prompt.substitute(
context=context_string, query=input_query, history=self.history or "- no history -"
)
elif self.history and not template_contains_history:
# History is present, but not included in the template.
# check if it's the default template without history
elif self.history and not prompt_contains_history:
# History is present, but not included in the prompt.
# check if it's the default prompt without history
if (
not self.config._validate_template_history(self.config.template)
and self.config.template.template == DEFAULT_PROMPT
not self.config._validate_prompt_history(self.config.prompt)
and self.config.prompt.template == DEFAULT_PROMPT
):
# swap in the template with history
prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
@@ -95,12 +95,12 @@ def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[
else:
# If we can't swap in the default, we still proceed but tell users that the history is ignored.
logging.warning(
"Your bot contains a history, but template does not include `$history` key. History is ignored."
"Your bot contains a history, but prompt does not include `$history` key. History is ignored."
)
prompt = self.config.template.substitute(context=context_string, query=input_query)
prompt = self.config.prompt.substitute(context=context_string, query=input_query)
else:
# basic use case, no history.
prompt = self.config.template.substitute(context=context_string, query=input_query)
prompt = self.config.prompt.substitute(context=context_string, query=input_query)
return prompt

def _append_search_and_context(self, context: str, web_search_result: str) -> str:
@@ -191,7 +191,7 @@ def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = N
return contexts

if self.is_docs_site_instance:
self.config.template = DOCS_SITE_PROMPT_TEMPLATE
self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
self.config.number_documents = 5
k = {}
if self.online:
@@ -242,7 +242,7 @@ def chat(self, input_query: str, contexts: List[str], config: BaseLlmConfig = No
self.config = config

if self.is_docs_site_instance:
self.config.template = DOCS_SITE_PROMPT_TEMPLATE
self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
self.config.number_documents = 5
k = {}
if self.online:
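The renamed attribute flows through `generate_prompt` with the behaviour unchanged: prompts with a `$history` placeholder receive the conversation history (or a `- no history -` placeholder), the default prompt is silently swapped for the history-aware default when history exists, and custom prompts without `$history` log a warning and ignore history. A simplified sketch of that branching, with hypothetical stand-ins for the real default templates:

```python
import logging
from string import Template
from typing import Optional

# Hypothetical simplified stand-ins for embedchain's real default prompts.
DEFAULT_PROMPT = "Context: $context\nQuery: $query\nAnswer:"
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE = Template(
    "History: $history\nContext: $context\nQuery: $query\nAnswer:"
)

def build_prompt(prompt: Template, context: str, query: str, history: Optional[str]) -> str:
    if "$history" in prompt.template:
        # Prompt expects history; fall back to a placeholder if there is none yet.
        return prompt.substitute(context=context, query=query, history=history or "- no history -")
    if history and prompt.template == DEFAULT_PROMPT:
        # Default prompt without history: swap in the history-aware default.
        return DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
            context=context, query=query, history=history
        )
    if history:
        # Custom prompt without a $history key: proceed, but history is ignored.
        logging.warning(
            "Your bot contains a history, but prompt does not include `$history` key. History is ignored."
        )
    return prompt.substitute(context=context, query=query)
```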
1 change: 1 addition & 0 deletions embedchain/utils.py
@@ -396,6 +396,7 @@ def validate_config(config_data):
Optional("top_p"): Or(float, int),
Optional("stream"): bool,
Optional("template"): str,
Optional("prompt"): str,
Optional("system_prompt"): str,
Optional("deployment_name"): str,
Optional("where"): dict,
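On the config-file side, `validate_config` in `embedchain/utils.py` now accepts both keys, so YAML configs using the old `template` key keep validating while new ones can use `prompt`. A minimal hedged sketch of just the fragment this hunk touches, using the `schema` package the file already relies on (the real schema covers many more keys):

```python
from schema import Optional, Or, Schema

# Only the fragment relevant to this commit, not embedchain's full schema.
llm_config_schema = Schema(
    {
        Optional("top_p"): Or(float, int),
        Optional("stream"): bool,
        Optional("template"): str,  # still accepted for backward compatibility
        Optional("prompt"): str,
        Optional("system_prompt"): str,
    },
    ignore_extra_keys=True,
)

# Both the old and the new key pass schema validation.
llm_config_schema.validate({"prompt": "Use $context to answer: $query", "stream": False})
llm_config_schema.validate({"template": "Use $context to answer: $query"})
```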
2 changes: 1 addition & 1 deletion tests/helper_classes/test_json_serializable.py
@@ -76,4 +76,4 @@ def test_special_subclasses(self):
config = BaseLlmConfig(template=Template("My custom template with $query, $context and $history."))
s = config.serialize()
new_config: BaseLlmConfig = BaseLlmConfig.deserialize(s)
self.assertEqual(config.template.template, new_config.template.template)
self.assertEqual(config.prompt.template, new_config.prompt.template)
2 changes: 1 addition & 1 deletion tests/llm/test_base_llm.py
@@ -25,7 +25,7 @@ def test_is_stream_bool():
def test_template_string_gets_converted_to_Template_instance():
config = BaseLlmConfig(template="test value $query $context")
llm = BaseLlm(config=config)
assert isinstance(llm.config.template, Template)
assert isinstance(llm.config.prompt, Template)


def test_is_get_llm_model_answer_implemented():
4 changes: 2 additions & 2 deletions tests/llm/test_generate_prompt.py
@@ -53,15 +53,15 @@ def test_generate_prompt_with_contexts_list(self):
result = self.app.llm.generate_prompt(input_query, contexts)

# Assert
expected_result = config.template.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
expected_result = config.prompt.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
self.assertEqual(result, expected_result)

def test_generate_prompt_with_history(self):
"""
Test the 'generate_prompt' method with BaseLlmConfig containing a history attribute.
"""
config = BaseLlmConfig()
config.template = Template("Context: $context | Query: $query | History: $history")
config.prompt = Template("Context: $context | Query: $query | History: $history")
self.app.llm.config = config
self.app.llm.set_history(["Past context 1", "Past context 2"])
prompt = self.app.llm.generate_prompt("Test query", ["Test context"])
