Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Weave agent personality through subtasks #916

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/khoj/processor/conversation/anthropic/anthropic_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def extract_questions_anthropic(
temperature=0.7,
location_data: LocationData = None,
user: KhojUser = None,
personality_context: Optional[str] = None,
):
"""
Infer search queries to retrieve relevant notes to answer user query
Expand Down Expand Up @@ -59,6 +60,7 @@ def extract_questions_anthropic(
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
location=location,
username=username,
personality_context=personality_context,
)

prompt = prompts.extract_questions_anthropic_user_message.format(
Expand Down
2 changes: 2 additions & 0 deletions src/khoj/processor/conversation/google/gemini_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def extract_questions_gemini(
max_tokens=None,
location_data: LocationData = None,
user: KhojUser = None,
personality_context: Optional[str] = None,
):
"""
Infer search queries to retrieve relevant notes to answer user query
Expand Down Expand Up @@ -60,6 +61,7 @@ def extract_questions_gemini(
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
location=location,
username=username,
personality_context=personality_context,
)

prompt = prompts.extract_questions_anthropic_user_message.format(
Expand Down
4 changes: 3 additions & 1 deletion src/khoj/processor/conversation/offline/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from datetime import datetime, timedelta
from threading import Thread
from typing import Any, Iterator, List, Union
from typing import Any, Iterator, List, Optional, Union

from langchain.schema import ChatMessage
from llama_cpp import Llama
Expand Down Expand Up @@ -33,6 +33,7 @@ def extract_questions_offline(
user: KhojUser = None,
max_prompt_size: int = None,
temperature: float = 0.7,
personality_context: Optional[str] = None,
) -> List[str]:
"""
Infer search queries to retrieve relevant notes to answer user query
Expand Down Expand Up @@ -73,6 +74,7 @@ def extract_questions_offline(
this_year=today.year,
location=location,
username=username,
personality_context=personality_context,
)

messages = generate_chatml_messages_with_context(
Expand Down
2 changes: 2 additions & 0 deletions src/khoj/processor/conversation/openai/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def extract_questions(
user: KhojUser = None,
uploaded_image_url: Optional[str] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
):
"""
Infer search queries to retrieve relevant notes to answer user query
Expand Down Expand Up @@ -68,6 +69,7 @@ def extract_questions(
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
location=location,
username=username,
personality_context=personality_context,
)

prompt = construct_structured_message(
Expand Down
22 changes: 18 additions & 4 deletions src/khoj/processor/conversation/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@

image_generation_improve_prompt_base = """
You are a talented media artist with the ability to describe images to compose in professional, fine detail.
{personality_context}
Generate a vivid description of the image to be rendered using the provided context and user prompt below:

Today's Date: {current_date}
Expand Down Expand Up @@ -210,6 +211,7 @@
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked meta, vague or random questions, search for a variety of broad topics to answer the user's question.
- Share relevant search queries as a JSON list of strings. Do not say anything else.
{personality_context}

Current Date: {day_of_week}, {current_date}
User's Location: {location}
Expand Down Expand Up @@ -260,7 +262,7 @@
- Break messages into multiple search queries when required to retrieve the relevant information.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked meta, vague or random questions, search for a variety of broad topics to answer the user's question.

{personality_context}
What searches will you perform to answer the user's question? Respond with search queries as a list of strings in a JSON object.
Current Date: {day_of_week}, {current_date}
User's Location: {location}
Expand Down Expand Up @@ -317,7 +319,7 @@
- Break messages into multiple search queries when required to retrieve the relevant information.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked meta, vague or random questions, search for a variety of broad topics to answer the user's question.

{personality_context}
What searches will you perform to answer the user's question? Respond with a JSON object with the key "queries" mapping to a list of searches you would perform on the user's knowledge base. Just return the queries and nothing else.

Current Date: {day_of_week}, {current_date}
Expand Down Expand Up @@ -375,6 +377,7 @@

extract_relevant_information = PromptTemplate.from_template(
"""
{personality_context}
Target Query: {query}

Web Pages:
Expand All @@ -400,6 +403,7 @@

extract_relevant_summary = PromptTemplate.from_template(
"""
{personality_context}
Target Query: {query}

Document Contents:
Expand All @@ -409,9 +413,18 @@
""".strip()
)

# Persona snippet injected into the other prompt templates above via their
# {personality_context} placeholder. Callers format this with the agent's
# configured personality when an agent is present, and pass an empty string
# otherwise (see extract_references_and_questions in routers/api.py).
personality_context = PromptTemplate.from_template(
    """
Here's some additional context about you:
{personality}

"""
)

pick_relevant_output_mode = PromptTemplate.from_template(
"""
You are Khoj, an excellent analyst for selecting the correct way to respond to a user's query.
{personality_context}
You have access to a limited set of modes for your response.
You can only use one of these modes.

Expand Down Expand Up @@ -464,6 +477,7 @@
pick_relevant_information_collection_tools = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful search assistant.
{personality_context}
- You have access to a variety of data sources to help you answer the user's question
- You can use the data sources listed below to collect more relevant information
- You can use any combination of these data sources to answer the user's question
Expand Down Expand Up @@ -538,7 +552,7 @@
- Add as much context from the previous questions and answers as required to construct the webpage urls.
- Use multiple web page urls if required to retrieve the relevant information.
- You have access to the whole internet to retrieve information.

{personality_context}
Which webpages will you need to read to answer the user's question?
Provide web page links as a list of strings in a JSON object.
Current Date: {current_date}
Expand Down Expand Up @@ -585,7 +599,7 @@
- Use site: google search operator when appropriate
- You have access to the whole internet to retrieve information.
- Official, up-to-date information about you, Khoj, is available at site:khoj.dev, github or pypi.

{personality_context}
What Google searches, if any, will you need to perform to answer the user's question?
Provide search queries as a list of strings in a JSON object.
Current Date: {current_date}
Expand Down
4 changes: 3 additions & 1 deletion src/khoj/processor/image/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import requests

from khoj.database.adapters import ConversationAdapters
from khoj.database.models import KhojUser, TextToImageModelConfig
from khoj.database.models import Agent, KhojUser, TextToImageModelConfig
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
from khoj.routers.storage import upload_image
from khoj.utils import state
Expand All @@ -28,6 +28,7 @@ async def text_to_image(
subscribed: bool = False,
send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None,
agent: Agent = None,
):
status_code = 200
image = None
Expand Down Expand Up @@ -67,6 +68,7 @@ async def text_to_image(
model_type=text_to_image_config.model_type,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url,
agent=agent,
)

if send_status_func:
Expand Down
16 changes: 9 additions & 7 deletions src/khoj/processor/tools/online_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from bs4 import BeautifulSoup
from markdownify import markdownify

from khoj.database.models import KhojUser
from khoj.database.models import Agent, KhojUser
from khoj.routers.helpers import (
ChatEvent,
extract_relevant_info,
Expand Down Expand Up @@ -58,16 +58,17 @@ async def search_online(
send_status_func: Optional[Callable] = None,
custom_filters: List[str] = [],
uploaded_image_url: str = None,
agent: Agent = None,
):
query += " ".join(custom_filters)
if not is_internet_connected():
logger.warn("Cannot search online as not connected to internet")
logger.warning("Cannot search online as not connected to internet")
yield {}
return

# Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(
query, conversation_history, location, user, uploaded_image_url=uploaded_image_url
query, conversation_history, location, user, uploaded_image_url=uploaded_image_url, agent=agent
)
response_dict = {}

Expand Down Expand Up @@ -102,7 +103,7 @@ async def search_online(
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [
read_webpage_and_extract_content(subquery, link, content, subscribed=subscribed)
read_webpage_and_extract_content(subquery, link, content, subscribed=subscribed, agent=agent)
for link, subquery, content in webpages
]
results = await asyncio.gather(*tasks)
Expand Down Expand Up @@ -143,6 +144,7 @@ async def read_webpages(
subscribed: bool = False,
send_status_func: Optional[Callable] = None,
uploaded_image_url: str = None,
agent: Agent = None,
):
"Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read")
Expand All @@ -156,7 +158,7 @@ async def read_webpages(
webpage_links_str = "\n- " + "\n- ".join(list(urls))
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content(query, url, subscribed=subscribed) for url in urls]
tasks = [read_webpage_and_extract_content(query, url, subscribed=subscribed, agent=agent) for url in urls]
results = await asyncio.gather(*tasks)

response: Dict[str, Dict] = defaultdict(dict)
Expand All @@ -167,14 +169,14 @@ async def read_webpages(


async def read_webpage_and_extract_content(
subquery: str, url: str, content: str = None, subscribed: bool = False
subquery: str, url: str, content: str = None, subscribed: bool = False, agent: Agent = None
) -> Tuple[str, Union[None, str], str]:
try:
if is_none_or_empty(content):
with timer(f"Reading web page at '{url}' took", logger):
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(subquery, content, subscribed=subscribed)
extracted_info = await extract_relevant_info(subquery, content, subscribed=subscribed, agent=agent)
return subquery, extracted_info, url
except Exception as e:
logger.error(f"Failed to read web page at '{url}' with {e}")
Expand Down
15 changes: 14 additions & 1 deletion src/khoj/routers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@
get_user_photo,
get_user_search_model_or_default,
)
from khoj.database.models import ChatModelOptions, KhojUser, SpeechToTextModelOptions
from khoj.database.models import (
Agent,
ChatModelOptions,
KhojUser,
SpeechToTextModelOptions,
)
from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.anthropic_chat import (
extract_questions_anthropic,
)
Expand Down Expand Up @@ -333,6 +339,7 @@ async def extract_references_and_questions(
location_data: LocationData = None,
send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None,
agent: Agent = None,
):
user = request.user.object if request.user.is_authenticated else None

Expand Down Expand Up @@ -368,6 +375,8 @@ async def extract_references_and_questions(
using_offline_chat = False
logger.debug(f"Filters in query: {filters_in_query}")

personality_context = prompts.personality_context.format(personality=agent.personality) if agent else ""

# Infer search queries from user message
with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
Expand All @@ -392,6 +401,7 @@ async def extract_references_and_questions(
location_data=location_data,
user=user,
max_prompt_size=conversation_config.max_prompt_size,
personality_context=personality_context,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = conversation_config.openai_config
Expand All @@ -408,6 +418,7 @@ async def extract_references_and_questions(
user=user,
uploaded_image_url=uploaded_image_url,
vision_enabled=vision_enabled,
personality_context=personality_context,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key
Expand All @@ -419,6 +430,7 @@ async def extract_references_and_questions(
conversation_log=meta_log,
location_data=location_data,
user=user,
personality_context=personality_context,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key
Expand All @@ -431,6 +443,7 @@ async def extract_references_and_questions(
location_data=location_data,
max_tokens=conversation_config.max_prompt_size,
user=user,
personality_context=personality_context,
)

# Collate search results as context for GPT
Expand Down
Loading
Loading