
Commit 9be3325

Add ChatGPT-4 Model

Signed-off-by: Sebastiano Mariani <[email protected]>

1 parent ca11dbb, commit 9be3325

File tree

6 files changed, +64 -26 lines changed


pyproject.toml (+3)

@@ -47,3 +47,6 @@ ignored-modules = ""
 disable = """
 W1514,F0010,useless-super-delegation,E1103,W0108,W0404,R0904,R0922,W0105,
 W0142,C0301,C0321,C0322,C0324,R,W0232,E1001,W0212,W0703,C,I0011,I0012,I0013,E0012,W0511"""
+
+[project.scripts]
+llm-repl = "llm_repl.__main__:main"
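
The new [project.scripts] table is what exposes the llm-repl command once the package is installed. As a rough sketch of the mechanism (the real shim is generated by pip/setuptools at install time, not checked into this repo), the wrapper behaves like:

# Sketch of the installer-generated llm-repl console script.
# It simply imports the declared entry point and calls it.
import sys

from llm_repl.__main__ import main

if __name__ == "__main__":
    sys.exit(main())

This is why the next diff has to move the CLI logic out of the if __name__ == "__main__": guard and into a named main() function.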

src/llm_repl/__main__.py (+2 -1)

@@ -11,7 +11,8 @@
     }
 }
 
-if __name__ == "__main__":
+
+def main():
     parser = argparse.ArgumentParser(description="LLM REPL")
     parser.add_argument(
         "--model",

src/llm_repl/llms/__init__.py (+15 -9)

@@ -6,10 +6,13 @@
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import AgentAction, AgentFinish, LLMResult
 
-from llm_repl.repl import LLMRepl
-
 
 class BaseLLM(ABC):
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        """Return the name of the LLM."""
+
     @property
     @abstractmethod
     def is_in_streaming_mode(self):
@@ -25,6 +28,9 @@ def process(self, msg) -> str:
         """Process the user message and return the response."""
 
 
+from llm_repl.repl import LLMRepl
+
+
 class StreamingCallbackHandler(BaseCallbackHandler):
     """Callback handler for streaming. Only works with LLMs that support streaming."""
 
@@ -80,11 +86,11 @@ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
         """Run on agent end."""
 
 
-MODELS: Dict[str, Type[BaseLLM]] = {}
-
+from llm_repl.llms.chatgpt import ChatGPT
+from llm_repl.llms.chatgpt4 import ChatGPT4
 
-def register_model(model: Type[BaseLLM]):
-    """
-    Register a model in the list of available models.
-    """
-    MODELS[model.__name__.lower()] = model  # type: ignore
+# TODO: Implement dynamic loading of models
+MODELS: Dict[str, Type[BaseLLM]] = {
+    "chatgpt": ChatGPT,
+    "chatgpt4": ChatGPT4,
+}
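
Two changes happen here. First, from llm_repl.repl import LLMRepl moves below the BaseLLM definition: repl.py and llms/__init__.py import from each other, so deferring the import until BaseLLM exists sidesteps the circular-import failure. Second, the register_model() hook is replaced with a hand-maintained static registry (the TODO notes that dynamic discovery is planned). Consuming the registry then looks roughly like this (a sketch; the repl instance is assumed to exist already):

# Assumed usage of the static MODELS registry.
from llm_repl.llms import MODELS

model_cls = MODELS["chatgpt4"]  # a Type[BaseLLM], here ChatGPT4
llm = model_cls.load(repl)      # returns None if OPENAI_API_KEY is unset
if llm is not None:
    print(llm.name)             # "ChatGPT-4"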

src/llm_repl/llms/chatgpt.py (+7 -6)

@@ -18,7 +18,7 @@
 from rich.markdown import Markdown
 
 from llm_repl.repl import LLMRepl
-from llm_repl.llms import BaseLLM, StreamingCallbackHandler, register_model
+from llm_repl.llms import BaseLLM, StreamingCallbackHandler  # , register_model
 
 
 class ChatGPTStreamingCallbackHandler(StreamingCallbackHandler):
@@ -54,7 +54,7 @@ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
 
 
 class ChatGPT(BaseLLM):
-    def __init__(self, api_key: str, repl: LLMRepl):
+    def __init__(self, api_key: str, repl: LLMRepl, model_name: str = "gpt-3.5-turbo"):
         self.api_key = api_key
         # TODO: Make options configurable
         self.streaming_mode = True
@@ -76,11 +76,15 @@ def __init__(self, api_key: str, repl: LLMRepl):
             streaming=self.streaming_mode,
             callback_manager=CallbackManager([ChatGPTStreamingCallbackHandler(repl)]),
             verbose=True,
-            # temperature=0,
+            model_name=model_name,
         )  # type: ignore
         memory = ConversationBufferMemory(return_messages=True)
         self.model = ConversationChain(memory=memory, prompt=prompt, llm=llm)
 
+    @property
+    def name(self) -> str:
+        return "ChatGPT"
+
     @property
     def is_in_streaming_mode(self) -> bool:
         return self.streaming_mode
@@ -101,6 +105,3 @@ def load(cls, repl: LLMRepl) -> Optional[BaseLLM]:
     def process(self, msg: str) -> str:
         resp = self.model.predict(input=msg)
         return resp.strip()
-
-
-register_model(ChatGPT)
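
The new model_name parameter (defaulting to gpt-3.5-turbo, the model the class previously relied on implicitly) is the hook the rest of the commit builds on: one class can now front any OpenAI chat model. A sketch of direct use (the second call mirrors what ChatGPT4.load does; it is illustrative, not code from this commit):

# Illustrative only: pointing ChatGPT at different OpenAI chat models.
chatgpt = ChatGPT(api_key, repl)                       # uses gpt-3.5-turbo
chatgpt4 = ChatGPT(api_key, repl, model_name="gpt-4")  # what ChatGPT4.load builds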

src/llm_repl/llms/chatgpt4.py (+28, new file)

@@ -0,0 +1,28 @@
+from __future__ import annotations
+
+import os
+
+from typing import Optional
+
+from llm_repl.repl import LLMRepl
+from llm_repl.llms import BaseLLM
+from llm_repl.llms.chatgpt import ChatGPT
+
+
+class ChatGPT4(ChatGPT):
+    @property
+    def name(self) -> str:
+        return "ChatGPT-4"
+
+    @classmethod
+    def load(cls, repl: LLMRepl) -> Optional[BaseLLM]:
+        api_key = os.getenv("OPENAI_API_KEY")
+        if api_key is None:
+            repl.print_error_msg(
+                "OpenAI API key not found. Please set the OPENAI_API_KEY environment variable."
+            )
+            return None
+
+        # TODO: Add autocomplete in repl
+        model = cls(api_key, repl, model_name="gpt-4")
+        return model
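
ChatGPT4 carries no state of its own: it inherits ChatGPT's constructor, pins model_name to gpt-4, and overrides only the name property and the load() factory. A further variant would follow the same pattern (hypothetical class, not part of this commit):

# Hypothetical: another OpenAI chat model added the same way.
class ChatGPT4_32K(ChatGPT):
    @property
    def name(self) -> str:
        return "ChatGPT-4-32K"

    @classmethod
    def load(cls, repl: LLMRepl) -> Optional[BaseLLM]:
        api_key = os.getenv("OPENAI_API_KEY")
        if api_key is None:
            repl.print_error_msg("OpenAI API key not found.")
            return None
        return cls(api_key, repl, model_name="gpt-4-32k")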

src/llm_repl/repl.py (+9 -10)

@@ -1,4 +1,4 @@
-from typing import Any, Callable, Optional, Type
+from typing import Any, Callable
 
 from prompt_toolkit import PromptSession
 from prompt_toolkit.completion import WordCompleter
@@ -7,6 +7,7 @@
 from rich.console import Console
 from rich.markdown import Markdown
 
+# from llm_repl.llms import BaseLLM
 
 LLM_CMD_HANDLERS: dict[str, Callable] = {}
 
@@ -34,7 +35,7 @@ def __init__(self, config: dict[str, Any]):
         self.server_color = config["style"]["server"]["color"]
         self.error_color = "bold red"
         self.misc_color = "gray"
-        self.model: Optional[BaseLLM] = None
+        self.model = None  # Optional[BaseLLM] = None
 
     def handle_enter(self, event):
         """
@@ -95,15 +96,15 @@ def print_error_msg(self, msg: str):
         """
         self._print_msg(self.ERROR_MSG_TITLE, msg, self.error_color)
 
-    def print_misc_msg(self, msg: str):
+    def print_misc_msg(self, msg: str, justify: str = "left"):
         """
         Print the miscellaneous message in the console.
 
         :param str msg: The message to be printed.
         """
-        self._print_msg("", msg, self.misc_color, justify="center")
+        self._print_msg("", msg, self.misc_color, justify=justify)
 
-    def run(self, model: Type[BaseLLM]):  # type: ignore
+    def run(self, model):
         """
         Starts the REPL.
 
@@ -115,11 +116,12 @@ def run(self, model: Type[BaseLLM]):  # type: ignore
         :param BaseLLM model: The LLM model to use.
         """
 
-        self.model = model.load(self)  # type: ignore
+        self.model = model.load(self)
         if self.model is None:
             return
 
-        self.print_misc_msg(self.INTRO_BANNER)
+        self.print_misc_msg(self.INTRO_BANNER, justify="center")
+        self.print_misc_msg(f"Loaded model: {self.model.name}", justify="center")
 
         while True:
             user_input = self.session.prompt("> ").rstrip()
@@ -140,6 +142,3 @@
             else:
                 self.console.print()
             self.console.rule(style=self.server_color)
-
-
-from llm_repl.llms import BaseLLM
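
Besides dropping the BaseLLM annotations (the same circular-import problem addressed in llms/__init__.py, traded here for plain untyped attributes and a commented-out import), this file flips print_misc_msg to left alignment by default, with run() opting back into centering for the banner and the new model line. Assumed usage of the new parameter (a sketch):

# Sketch: the justify keyword added to print_misc_msg.
repl.print_misc_msg("some status line")            # left-aligned (new default)
repl.print_misc_msg("Welcome!", justify="center")  # old behavior, now opt-in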
