Skip to content

Commit 4ea160a

Browse files
authored
Merge pull request #28 from ericmjl/api-server
API Server
2 parents 55ec13e + 4469b35 commit 4ea160a

File tree

16 files changed

+305
-414
lines changed

16 files changed

+305
-414
lines changed

llamabot/bot/chatbot.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def __call__(self, message: str) -> AIMessage:
6565
query=human_message, character_budget=self.response_budget
6666
)
6767
messages = [self.system_prompt] + history + [human_message]
68+
if self.stream:
69+
return self.stream_response(messages)
6870
response = self.generate_response(messages)
6971
autorecord(message, response.content)
7072

llamabot/bot/imagebot.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import requests
66
from pathlib import Path
77
from typing import Optional, Union
8-
from langchain.schema import AIMessage
8+
from llamabot.components.messages import AIMessage
99

1010

1111
class ImageBot:
@@ -34,6 +34,7 @@ def __call__(
3434
If it is empty, then we will generate a filename from the prompt.
3535
:return: The URL of the generated image if running in a Jupyter notebook (str),
3636
otherwise a pathlib.Path object pointing to the generated image.
37+
:raises Exception: If no image URL is found in the response.
3738
"""
3839
response = self.client.images.generate(
3940
model=self.model,
@@ -43,6 +44,8 @@ def __call__(
4344
n=self.n,
4445
)
4546
image_url = response.data[0].url
47+
if not image_url:
48+
raise Exception("No image URL found in response! Please try again.")
4649

4750
# Check if running in a Jupyter notebook
4851
if is_running_in_jupyter():

llamabot/bot/model_dispatcher.py

Lines changed: 0 additions & 172 deletions
This file was deleted.

llamabot/bot/querybot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from llamabot.bot.simplebot import SimpleBot
1010
from llamabot.components.messages import AIMessage, HumanMessage
1111
from llamabot.components.docstore import DocumentStore
12-
12+
from llamabot.components.api import APIMixin
1313
from llamabot.components.messages import (
1414
RetrievedMessage,
1515
retrieve_messages_up_to_budget,
@@ -24,7 +24,7 @@
2424
prompt_recorder_var = contextvars.ContextVar("prompt_recorder")
2525

2626

27-
class QueryBot(SimpleBot, DocumentStore):
27+
class QueryBot(SimpleBot, DocumentStore, APIMixin):
2828
"""QueryBot is a bot that uses simple RAG to answer questions about a document."""
2929

3030
def __init__(

llamabot/bot/simplebot.py

Lines changed: 55 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Class definition for SimpleBot."""
22
import contextvars
3-
from typing import Optional
3+
from typing import Optional, Union
44

55

66
from llamabot.components.messages import (
@@ -29,6 +29,8 @@ class SimpleBot:
2929
:param model_name: The name of the model to use.
3030
:param stream: Whether to stream the output to stdout.
3131
:param json_mode: Whether to print debug messages.
32+
:param api_key: The OpenAI API key to use.
33+
:param mock_response: A mock response to use, for testing purposes only.
3234
"""
3335

3436
def __init__(
@@ -39,56 +41,81 @@ def __init__(
3941
stream=True,
4042
json_mode: bool = False,
4143
api_key: Optional[str] = None,
44+
mock_response: Optional[str] = None,
4245
):
4346
self.system_prompt: SystemMessage = SystemMessage(content=system_prompt)
4447
self.model_name = model_name
4548
self.temperature = temperature
4649
self.stream = stream
4750
self.json_mode = json_mode
4851
self.api_key = api_key
52+
self.mock_response = mock_response
4953

50-
def __call__(self, human_message: str) -> AIMessage:
54+
def __call__(self, human_message: str) -> Union[AIMessage, str]:
5155
"""Call the SimpleBot.
5256
5357
:param human_message: The human message to use.
5458
:return: The response to the human message, primed by the system prompt.
5559
"""
5660

57-
messages: list[BaseMessage] = [
58-
self.system_prompt,
59-
HumanMessage(content=human_message),
60-
]
61+
messages = [self.system_prompt, HumanMessage(content=human_message)]
62+
if self.stream:
63+
return self.stream_response(messages)
6164
response = self.generate_response(messages)
6265
autorecord(human_message, response.content)
6366
return response
6467

6568
def generate_response(self, messages: list[BaseMessage]) -> AIMessage:
66-
"""Generate a response from the given messages."""
67-
68-
messages_dumped: list[dict] = [m.model_dump() for m in messages]
69-
completion_kwargs = dict(
70-
model=self.model_name,
71-
messages=messages_dumped,
72-
temperature=self.temperature,
73-
stream=self.stream,
74-
)
75-
if self.json_mode:
76-
completion_kwargs["response_format"] = {"type": "json_object"}
77-
if self.api_key:
78-
completion_kwargs["api_key"] = self.api_key
79-
response = completion(**completion_kwargs)
69+
"""Generate a response from the given messages.
8070
81-
if self.stream:
82-
ai_message = ""
83-
for chunk in response:
84-
delta = chunk.choices[0].delta.content
85-
if delta is not None:
86-
print(delta, end="")
87-
ai_message += delta
88-
return AIMessage(content=ai_message)
71+
:param messages: A list of messages.
72+
:return: The response to the messages.
73+
"""
8974

75+
response = _make_response(self, messages)
9076
return AIMessage(content=response.choices[0].message.content)
9177

78+
def stream_response(self, messages: list[BaseMessage]) -> str:
79+
"""Stream the response from the given messages.
80+
81+
This is intended to be used with Panel's ChatInterface as part of the callback.
82+
83+
:param messages: A list of messages.
84+
:return: A generator that yields the response.
85+
"""
86+
response = _make_response(self, messages)
87+
message = ""
88+
for chunk in response:
89+
delta = chunk.choices[0].delta.content
90+
if delta is not None:
91+
message += delta
92+
print(delta, end="")
93+
yield message
94+
print()
95+
96+
97+
def _make_response(bot: SimpleBot, messages: list[BaseMessage]):
98+
"""Make a response from the given messages.
99+
100+
:param bot: A SimpleBot
101+
:param messages: A list of Messages.
102+
:return: A response object.
103+
"""
104+
messages_dumped: list[dict] = [m.model_dump() for m in messages]
105+
completion_kwargs = dict(
106+
model=bot.model_name,
107+
messages=messages_dumped,
108+
temperature=bot.temperature,
109+
stream=bot.stream,
110+
)
111+
if bot.mock_response:
112+
completion_kwargs["mock_response"] = bot.mock_response
113+
if bot.json_mode:
114+
completion_kwargs["response_format"] = {"type": "json_object"}
115+
if bot.api_key:
116+
completion_kwargs["api_key"] = bot.api_key
117+
return completion(**completion_kwargs)
118+
92119
# Commented out until later.
93120
# def panel(
94121
# self,

llamabot/cli/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from llamabot import ChatBot, PromptRecorder
1010

11-
from . import blog, configure, doc, git, python, tutorial, zotero, repo
11+
from . import blog, configure, doc, git, python, tutorial, zotero, repo, serve
1212
from .utils import exit_if_asked, uniform_prompt
1313

1414
app = typer.Typer()
@@ -39,6 +39,9 @@
3939
)
4040
app.add_typer(configure.app, name="configure", help="Configure LlamaBot.")
4141
app.add_typer(repo.app, name="repo", help="Chat with a code repository.")
42+
app.add_typer(
43+
serve.cli, name="serve", help="Serve up a LlamaBot as a FastAPI endpoint."
44+
)
4245

4346

4447
@app.command()

llamabot/cli/git.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ def compose_commit():
106106
try:
107107
diff = get_git_diff()
108108
bot = commitbot()
109-
bot(write_commit_message(diff))
109+
msg = bot(write_commit_message(diff))
110+
echo(msg.content)
110111
except Exception as e:
111112
echo(f"Error encountered: {e}", err=True)
112113
echo("Please write your own commit message.", err=True)

0 commit comments

Comments
 (0)