1
- from typing import Any , Callable , Optional , Type
1
+ from typing import Any , Callable
2
2
3
3
from prompt_toolkit import PromptSession
4
4
from prompt_toolkit .completion import WordCompleter
7
7
from rich .console import Console
8
8
from rich .markdown import Markdown
9
9
10
+ # from llm_repl.llms import BaseLLM
10
11
11
12
LLM_CMD_HANDLERS : dict [str , Callable ] = {}
12
13
@@ -34,7 +35,7 @@ def __init__(self, config: dict[str, Any]):
34
35
self .server_color = config ["style" ]["server" ]["color" ]
35
36
self .error_color = "bold red"
36
37
self .misc_color = "gray"
37
- self .model : Optional [BaseLLM ] = None
38
+ self .model = None # Optional[BaseLLM] = None
38
39
39
40
def handle_enter (self , event ):
40
41
"""
@@ -95,15 +96,15 @@ def print_error_msg(self, msg: str):
95
96
"""
96
97
self ._print_msg (self .ERROR_MSG_TITLE , msg , self .error_color )
97
98
98
- def print_misc_msg (self , msg : str ):
99
+ def print_misc_msg (self , msg : str , justify : str = "left" ):
99
100
"""
100
101
Print the miscellaneous message in the console.
101
102
102
103
:param str msg: The message to be printed.
103
104
"""
104
- self ._print_msg ("" , msg , self .misc_color , justify = "center" )
105
+ self ._print_msg ("" , msg , self .misc_color , justify = justify )
105
106
106
- def run (self , model : Type [ BaseLLM ]): # type: ignore
107
+ def run (self , model ):
107
108
"""
108
109
Starts the REPL.
109
110
@@ -115,11 +116,12 @@ def run(self, model: Type[BaseLLM]): # type: ignore
115
116
:param BaseLLM model: The LLM model to use.
116
117
"""
117
118
118
- self .model = model .load (self ) # type: ignore
119
+ self .model = model .load (self )
119
120
if self .model is None :
120
121
return
121
122
122
- self .print_misc_msg (self .INTRO_BANNER )
123
+ self .print_misc_msg (self .INTRO_BANNER , justify = "center" )
124
+ self .print_misc_msg (f"Loaded model: { self .model .name } " , justify = "center" )
123
125
124
126
while True :
125
127
user_input = self .session .prompt ("> " ).rstrip ()
@@ -140,6 +142,3 @@ def run(self, model: Type[BaseLLM]): # type: ignore
140
142
else :
141
143
self .console .print ()
142
144
self .console .rule (style = self .server_color )
143
-
144
-
145
- from llm_repl .llms import BaseLLM
0 commit comments