|
5 | 5 | from exllamav3 import Generator, Job, model_init
|
6 | 6 | from exllamav3.generator.sampler import ComboSampler
|
7 | 7 | from chat_templates import *
|
| 8 | +from chat_util import * |
8 | 9 | import torch
|
9 | 10 | from chat_console import *
|
10 | 11 |
|
@@ -61,15 +62,37 @@ def main(args):
|
61 | 62 | # Main loop
|
62 | 63 | print("\n" + col_sysprompt + system_prompt.strip() + col_default)
|
63 | 64 | context = []
|
| 65 | + response = "" |
64 | 66 |
|
65 | 67 | while True:
|
66 | 68 |
|
67 | 69 | # Amnesia mode
|
68 | 70 | if args.amnesia:
|
69 | 71 | context = []
|
70 | 72 |
|
71 |
| - # Get user prompt and add to context |
| 73 | + # Get user prompt |
72 | 74 | user_prompt = read_input_fn(args, user_name)
|
| 75 | + |
| 76 | + # Intercept commands |
| 77 | + if user_prompt.startswith("/"): |
| 78 | + c = user_prompt.strip() |
| 79 | + match c: |
| 80 | + case "/x": |
| 81 | + print_info("Exiting") |
| 82 | + break |
| 83 | + case "/cc": |
| 84 | + snippet = copy_last_codeblock(response) |
| 85 | + if not snippet: |
| 86 | + print_error("No code block found in last response") |
| 87 | + else: |
| 88 | + num_lines = len(snippet.split("\n")) |
| 89 | + print_info(f"Copied {num_lines} line{'s' if num_lines > 1 else ''} to the clipboard") |
| 90 | + continue |
| 91 | + case _: |
| 92 | + print_error(f"Unknown command: {c}") |
| 93 | + continue |
| 94 | + |
| 95 | + # Add to context |
73 | 96 | context.append((user_prompt, None))
|
74 | 97 |
|
75 | 98 | # Tokenize context and trim from head if too long
|
@@ -141,7 +164,7 @@ def get_input_ids():
|
141 | 164 | parser.add_argument("-freqp", "--frequency_penalty", type = float, help = "Frequency penalty, 0 to disable (default: disabled)", default = 0.0)
|
142 | 165 | parser.add_argument("-penr", "--penalty_range", type = int, help = "Range for penalties, in tokens (default: 1024) ", default = 1024)
|
143 | 166 | parser.add_argument("-minp", "--min_p", type = float, help = "Min-P truncation, 0 to disable (default: 0.08)", default = 0.08)
|
144 |
| - parser.add_argument("-topk", "--top_k", type = float, help = "Top-K truncation, 0 to disable (default: disabled)", default = 0) |
| 167 | + parser.add_argument("-topk", "--top_k", type = int, help = "Top-K truncation, 0 to disable (default: disabled)", default = 0) |
145 | 168 | parser.add_argument("-topp", "--top_p", type = float, help = "Top-P truncation, 1 to disable (default: disabled)", default = 1.0)
|
146 | 169 | _args = parser.parse_args()
|
147 | 170 | main(_args)
|
0 commit comments