
Commit a905cff

Merge pull request #37 from turboderp-org/dev
Merge Dev to master
2 parents 81a0a7d + 70056fe commit a905cff


44 files changed: +990 −32 lines

eval/humaneval.py

Lines changed: 11 additions & 7 deletions

@@ -14,12 +14,12 @@
         " "
     ),
     "granite": (
-        "Question:\nComplete the following Python function:\n\n{{problem}}\n\nAnswer:\n"
+        "<|endoftext|>Question:\nComplete the following Python function:\n\n{{problem}}\n\nAnswer:\n"
         "Sure! Here is how you might implement the function:\n\n```python\n{{problem}}",
         " "
     ),
     "llama": (
-        "[INST] <<SYS>>\n"
+        "<s>[INST] <<SYS>>\n"
         "You are a helpful AI coding assistant.\n"
         "<</SYS>>\n\n"
         "Complete the following Python function:\n\n"

@@ -28,7 +28,7 @@
         " "
     ),
     "llama3": (
-        "<|start_header_id|>system<|end_header_id|>\n\n"
+        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
         "You are a helpful AI coding assistant.<|eot_id|>"
         "<|start_header_id|>user<|end_header_id|>\n\n"
         "Complete the following Python function:\n\n{{problem}}<|eot_id|>"

@@ -37,7 +37,7 @@
         " "
     ),
     "mistral": (
-        "[INST] You are a helpful AI coding assistant.\n\n"
+        "<s>[INST] You are a helpful AI coding assistant.\n\n"
         "Complete the following Python function:\n\n"
         "{{problem}}[/INST]"
         " Sure! Here is how you might implement the function:\n\n```python\n{{problem}}",

@@ -51,7 +51,7 @@
         " "
     ),
     "reka": (
-        "human: Complete the following Python function."
+        "<|endoftext|>human: Complete the following Python function."
         " Provide your reasoning in comments, but be concise and don't second-guess."
         "\n\n{{problem}}"
         " <sep> assistant: ```python\n{{problem}}",

@@ -76,7 +76,7 @@
         " "
     ),
     "deepseek": (
-        "You are a helpful AI coding assistant.\n"
+        "<|begin▁of▁sentence|>You are a helpful AI coding assistant.\n"
        "<|User|>Complete the following Python function:\n\n{{problem}}"
         "<|Assistant|>Sure! Here is how you might implement the function:\n\n```python\n{{problem}}",
         " "

@@ -124,7 +124,11 @@ def main(args):
     for idx, (problem_id, problem) in enumerate(problems.items()):
         b_problem = problem["prompt"]
         f_problem = prompt_format.replace("{{problem}}", b_problem)
-        input_ids = tokenizer.encode(f_problem, encode_special_tokens = True, add_bos = True)
+        input_ids = tokenizer.encode(
+            f_problem,
+            encode_special_tokens = True,
+            add_bos = (args.prompt_format == "raw")
+        )
         for s in range(num_samples_per_task):
             job = Job(
                 input_ids = input_ids,
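
With every template above now carrying its model's BOS token explicitly (<s>, <|begin_of_text|>, <|endoftext|>, <|begin▁of▁sentence|>), the tokenizer only prepends its own BOS for the "raw" format; otherwise the prompt would begin with two BOS tokens. A toy sketch of the failure mode this avoids (illustrative tokenizer, not the library's API):

def encode(text, add_bos):
    # Toy whitespace tokenizer that optionally prepends a BOS marker.
    tokens = text.split()
    return (["<s>"] if add_bos else []) + tokens

template = "<s> [INST] Complete the function [/INST]"  # BOS baked into the template

print(encode(template, add_bos = True))   # ['<s>', '<s>', ...] -- duplicated BOS
print(encode(template, add_bos = False))  # ['<s>', ...] -- correct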

examples/chat.py

Lines changed: 25 additions & 2 deletions

@@ -5,6 +5,7 @@
 from exllamav3 import Generator, Job, model_init
 from exllamav3.generator.sampler import ComboSampler
 from chat_templates import *
+from chat_util import *
 import torch
 from chat_console import *

@@ -61,15 +62,37 @@ def main(args):
     # Main loop
     print("\n" + col_sysprompt + system_prompt.strip() + col_default)
     context = []
+    response = ""

     while True:

         # Amnesia mode
         if args.amnesia:
             context = []

-        # Get user prompt and add to context
+        # Get user prompt
         user_prompt = read_input_fn(args, user_name)
+
+        # Intercept commands
+        if user_prompt.startswith("/"):
+            c = user_prompt.strip()
+            match c:
+                case "/x":
+                    print_info("Exiting")
+                    break
+                case "/cc":
+                    snippet = copy_last_codeblock(response)
+                    if not snippet:
+                        print_error("No code block found in last response")
+                    else:
+                        num_lines = len(snippet.split("\n"))
+                        print_info(f"Copied {num_lines} line{'s' if num_lines > 1 else ''} to the clipboard")
+                    continue
+                case _:
+                    print_error(f"Unknown command: {c}")
+                    continue
+
+        # Add to context
         context.append((user_prompt, None))

         # Tokenize context and trim from head if too long

@@ -141,7 +164,7 @@ def get_input_ids():
     parser.add_argument("-freqp", "--frequency_penalty", type = float, help = "Frequency penalty, 0 to disable (default: disabled)", default = 0.0)
     parser.add_argument("-penr", "--penalty_range", type = int, help = "Range for penalties, in tokens (default: 1024)", default = 1024)
     parser.add_argument("-minp", "--min_p", type = float, help = "Min-P truncation, 0 to disable (default: 0.08)", default = 0.08)
-    parser.add_argument("-topk", "--top_k", type = float, help = "Top-K truncation, 0 to disable (default: disabled)", default = 0)
+    parser.add_argument("-topk", "--top_k", type = int, help = "Top-K truncation, 0 to disable (default: disabled)", default = 0)
     parser.add_argument("-topp", "--top_p", type = float, help = "Top-P truncation, 1 to disable (default: disabled)", default = 1.0)
     _args = parser.parse_args()
     main(_args)
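
Two user-visible changes land here: the chat loop now intercepts slash commands before anything reaches the model (/x exits, /cc copies the last code block of the previous response to the clipboard via chat_util), and --top_k is parsed as an int rather than a float. The type fix lines up with what PyTorch's top-k operation expects; a sketch of the failure an integer-valued float can cause (assumed motivation, not stated in the commit):

import torch

logits = torch.randn(1, 32000)
k = 50.0  # what argparse produced with type = float

try:
    torch.topk(logits, k)  # PyTorch requires an integer k
except TypeError as e:
    print("rejected:", e)

print(torch.topk(logits, int(k)).values.shape)  # torch.Size([1, 50])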

examples/chat_console.py

Lines changed: 7 additions & 0 deletions

@@ -15,8 +15,15 @@
 col_think1 = "\u001b[35;1m"  # Bright magenta
 col_think2 = "\u001b[35m"  # Magenta
 col_error = "\u001b[31;1m"  # Bright red
+col_info = "\u001b[32;1m"  # Bright green
 col_sysprompt = "\u001b[37;1m"  # Grey

+def print_error(text):
+    print(col_error + "\nError: " + col_default + text)
+
+def print_info(text):
+    print(col_info + "\nInfo: " + col_default + text)
+
 def read_input_console(args, user_name):
     print("\n" + col_user + user_name + ": " + col_default, end = '', flush = True)
     if args.multiline:
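
These helpers wrap standard ANSI SGR escape sequences, nothing repo-specific: 31 is red, 32 is green, 35 is magenta, a trailing ;1 selects the bright variant, and \u001b[0m resets (so the new code 32;1 is bright green, not red as the committed comment said). A standalone line using the same convention:

col_info = "\u001b[32;1m"   # bright green, as used by print_info above
col_default = "\u001b[0m"   # reset to the terminal's default color
print(col_info + "\nInfo: " + col_default + "Copied 12 lines to the clipboard")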

examples/chat_util.py

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+import re
+import sys
+import pyperclip
+
+def copy_last_codeblock(text: str) -> str | None:
+    pattern = re.compile(r"```[^\n`]*\n(.*?)```", re.DOTALL)
+    matches = pattern.findall(text)
+    if not matches:
+        return None
+    snippet = matches[-1].strip()
+    pyperclip.copy(snippet)
+    return snippet
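
A brief usage sketch for the new helper (requires the pyperclip package and a working clipboard backend, e.g. xclip or wl-clipboard on Linux):

from chat_util import copy_last_codeblock

response = (
    "Here is one option:\n"
    "```python\n"
    "def add(a, b):\n"
    "    return a + b\n"
    "```\n"
)
snippet = copy_last_codeblock(response)
print(snippet)  # the code between the fences, now also on the clipboard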

exllamav3/conversion/allocation.py

Lines changed: 2 additions & 2 deletions

@@ -44,7 +44,7 @@ def allocate_transformer(
         assert d
         if isinstance(g, list):
             for m in (g, u, d):
-                key_ = m[0].key.replace(".slice.0", ".slice.*")
+                key_ = m[0].key.replace(".slice.0", ".slice.*").replace(".experts.0.", ".experts.*.")
                 keys += [key_]
                 numels += [sum(mm.weights_numel() for mm in m)]
                 for mm in m:

@@ -65,7 +65,7 @@
         assert d
         if isinstance(u, list):
             for m in (u, d):
-                key_ = m[0].key.replace(".slice.0", ".slice.*")
+                key_ = m[0].key.replace(".slice.0", ".slice.*").replace(".experts.0.", ".experts.*.")
                 keys += [m]
                 numels += [sum(mm.weights_numel() for mm in m)]
                 for mm in m:
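
The added .replace() collapses per-expert parameter keys into one wildcard entry, so a mixture-of-experts layer is budgeted as a single group in the allocation tables instead of once per expert. A toy illustration (the key string is invented for the example; real keys come from the model's modules):

key = "model.layers.0.mlp.experts.0.down.slice.0"
key_ = key.replace(".slice.0", ".slice.*").replace(".experts.0.", ".experts.*.")
print(key_)  # model.layers.0.mlp.experts.*.down.slice.*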

exllamav3/conversion/convert_model.py

Lines changed: 11 additions & 1 deletion

@@ -278,7 +278,9 @@ def main(args, job_state):
             qmaps = module.get_qmaps()
             if len(qmaps) > 0:

-                # Capture calibration input states during forward pass
+                # Capture calibration input states during forward pass. For block-sparse models, all expert layers
+                # are activated to ensure all down projections capture at least some calibration data. When the
+                # state is advanced later, only selected experts will be used.
                 with ProgressBar(f" -- Capturing: {module.key}" + slice_str, len(state)) as progress:
                     capture_H = {}
                     ref_states = []

@@ -287,12 +289,20 @@
                     params = {
                         "attn_mode": "flash_attn_nc",
                         "capture": capture_H,
+                        "activate_all_experts": model.calibration_all_experts,
                     }
                     if slicing:
                         params["q_mlp_slice"] = current_slice
                     rs = module.prepare_for_device(state[i], params)
                     rs = module.forward(rs, params)
                     if i < num_ref_states:
+                        if model.calibration_all_experts:
+                            # Reference states for measuring error must be captured with only the selected experts
+                            params = { "attn_mode": "flash_attn_nc" }
+                            if slicing:
+                                params["q_mlp_slice"] = current_slice
+                            rs = module.prepare_for_device(state[i], params)
+                            rs = module.forward(rs, params)
                         ref_states.append(rs.cpu())
                     rs = None
                 print(f" -- Captured: {module.key}" + slice_str)
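
The scheme in miniature: the capture pass forces every expert to process the calibration batch so no expert's input statistics are empty, while the reference output used to score quantization error comes from a normal routed pass. A toy stand-in for that behavior (assumed semantics of activate_all_experts, not the library's actual modules):

import torch

def moe_forward(x, experts, router, activate_all_experts = False):
    if activate_all_experts:
        # Calibration mode: every expert sees the batch and accumulates
        # statistics, even ones the router would never select.
        return sum(e(x) for e in experts) / len(experts)
    # Normal routed pass: only the top-1 expert handles each row.
    idx = router(x).argmax(dim = -1)
    return torch.stack([experts[i](x[n]) for n, i in enumerate(idx)])

experts = [torch.nn.Linear(8, 8) for _ in range(4)]
router = torch.nn.Linear(8, 4)
x = torch.randn(2, 8)

capture_out = moe_forward(x, experts, router, activate_all_experts = True)
ref_out = moe_forward(x, experts, router)  # reference for error measurement
print(capture_out.shape, ref_out.shape)    # torch.Size([2, 8]) twice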

exllamav3/exllamav3_ext/bindings.cpp

Lines changed: 1 addition & 0 deletions

@@ -55,6 +55,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("exl3_gemm", &exl3_gemm, "exl3_gemm");
     m.def("exl3_gemm_num_kernel_shapes", &exl3_gemm_num_kernel_shapes, "exl3_gemm_num_kernel_shapes");
     m.def("exl3_gemm_shape_compat", &exl3_gemm_shape_compat, "exl3_gemm_shape_compat");
+    m.def("exl3_mgemm", &exl3_mgemm, "exl3_mgemm");
     m.def("hgemm", &hgemm, "hgemm");
     m.def("rope", &rope, "rope");
     m.def("silu_mul", &silu_mul, "silu_mul");

exllamav3/exllamav3_ext/generator/rep_pen.cu

Lines changed: 2 additions & 1 deletion

@@ -75,7 +75,8 @@ void apply_rep_pens_kernel

         float w = v > 0.0f ? v / rep_p : v * rep_p;
         float f = factors[i] + 1e-30;
-        float o = v * (1.0f - f) + w * f;
+        float f1 = (1.0f - f) + 1e-30;
+        float o = v * f1 + w * f;
         out_logits[i + range_min] = o;
     }
 }
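
A plausible reading of the extra epsilon (an assumption; the commit carries no rationale): when a penalty factor saturates at exactly 1.0, the old blend multiplied the raw logit v by zero, and for masked logits at -inf that yields 0 * inf = NaN. Keeping the weight strictly positive preserves -inf instead. In Python terms:

import math

v = -math.inf   # a masked logit
w = v           # the penalized value; v * rep_p is still -inf
f = 1.0         # fully saturated penalty factor

old = v * (1.0 - f) + w * f             # -inf * 0.0 -> nan
new = v * ((1.0 - f) + 1e-30) + w * f   # -inf * tiny -> -inf
print(old, new)                         # nan -inf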

exllamav3/exllamav3_ext/quant/comp_units/exl3_comp_unit_1.cu

Lines changed: 6 additions & 0 deletions

@@ -17,4 +17,10 @@ fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp16_b1[] = {
     EXL3_GEMM_KERNEL_INSTANCES(1, false)
 };

+fp_exl3_mgemm_kernel tfp_exl3_mgemm_kernel_fp32_b1[] = {
+    EXL3_MGEMM_KERNEL_INSTANCES(1, true)
+};
+
+fp_exl3_mgemm_kernel tfp_exl3_mgemm_kernel_fp16_b1[] = {
+    EXL3_MGEMM_KERNEL_INSTANCES(1, false)
+};

exllamav3/exllamav3_ext/quant/comp_units/exl3_comp_unit_1.cuh

Lines changed: 2 additions & 0 deletions

@@ -2,3 +2,5 @@

 extern fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp32_b1[];
 extern fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp16_b1[];
+extern fp_exl3_mgemm_kernel tfp_exl3_mgemm_kernel_fp32_b1[];
+extern fp_exl3_mgemm_kernel tfp_exl3_mgemm_kernel_fp16_b1[];
