Commit 350713a

Merge branch 'dev'
2 parents 5cada8c + 41bcd48 commit 350713a


110 files changed, +7126 -782 lines changed

README.md

Lines changed: 4 additions & 10 deletions
```diff
@@ -1,13 +1,6 @@
 
 # <img src="doc/cat.png" width="40"> ExLlamaV3
 
-ExLlamaV3 is still in development. Please note: ↙
-
-- The framework <u>is not yet fully optimized</u>. Performance is lacking, especially on Ampere, and there may be a significant CPU bottleneck on slower processors until the extension functions are fully built out.
-- AMD GPUs (ROCm) are not yet supported.
-- [FlashAttention-2](https://github.com/Dao-AILab/flash-attention) is currently required. I hope to switch over to [FlashInfer](https://github.com/flashinfer-ai/flashinfer/tree/main) in time, but there are some obstacles to overcome first.
-- A number of important features are yet to be added, such as tensor parallelism.
-
 ## Why?
 
 As the name implies, the original intention for ExLlama was to run inference on quantized Llama models. ExLlamaV2 was able to support a number of other architectures by treating every new model as (more or less) a Llama variant with optional features. However, as new models are increasingly moving away from the basic transformer template, this approach is no longer sustainable.
@@ -18,12 +11,13 @@ Aside from lifting a few of the most successful features from V2 (such as the ge
 
 ## What's missing?
 
-There's much that still needs to be added and/or ported over from ExLlamaV2. I've decided to release ExLlamaV3 in its current state to invite testing, feedback and contributions, but please be aware that it's not yet a viable replacement for ExLlamaV2. Currently on the to-do list:
+Currently on the to-do list:
 
+- Lots of optimization
 - LoRA support
 - ROCm support
-- Tensor-parallel inference
-- Lots of optimization
+- More sampling functions
+- More quantization modes (FP4 etc.)
 
 As for what is implemented, expect that some things may be a little broken at first. Please be patient and/or contribute. 👉👈
 
```

examples/chat.py

Lines changed: 16 additions & 2 deletions
```diff
@@ -62,6 +62,9 @@ def main(args):
         temp_last = not args.temperature_first,
     )
 
+    # Single prompt mode
+    single_prompt = args.prompt
+
     # Main loop
     print("\n" + col_sysprompt + system_prompt.strip() + col_default)
     context = []
@@ -76,8 +79,18 @@ def main(args):
             context = []
 
         # Get user prompt
-        user_prompt = read_input_fn(args, user_name, multiline)
-        prefix = ""
+        if single_prompt is not None:
+            # This round, use provided prompt from cmdline
+            user_prompt = single_prompt
+            prefix = ""
+            # Next round, exit
+            single_prompt = "/x"
+        else:
+            try:
+                user_prompt = read_input_fn(args, user_name, multiline)
+                prefix = ""
+            except KeyboardInterrupt:
+                user_prompt = "/x"
 
         # Intercept commands
         if user_prompt.startswith("/"):
@@ -282,5 +295,6 @@ def get_input_ids(_prefix):
     parser.add_argument("-topk", "--top_k", type = int, help = "Top-K truncation, 0 to disable (default: disabled)", default = 0)
     parser.add_argument("-topp", "--top_p", type = float, help = "Top-P truncation, 1 to disable (default: disabled)", default = 1.0)
     parser.add_argument("-tps", "--show_tps", action = "store_true", help = "Show tokens/second after every reply")
+    parser.add_argument("-prompt", "--prompt", type = str, help = "Run single prompt, then exit")
     _args = parser.parse_args()
     main(_args)
```
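Net effect of the chat.py changes: passing `-prompt "..."` on the command line feeds a single prompt through the normal chat loop and then exits via the existing `/x` command, and Ctrl+C at the input prompt now exits cleanly instead of raising. A self-contained sketch of that sentinel-command pattern (a hypothetical standalone loop, not the actual chat.py code):

```python
# Hypothetical sketch of the single-prompt pattern added above: a one-shot prompt is fed
# through the normal REPL path, then the exit command "/x" is injected on the next pass.
def run(single_prompt: str | None = None):
    while True:
        if single_prompt is not None:
            user_prompt = single_prompt      # this round, use the command-line prompt
            single_prompt = "/x"             # next round, trigger the exit command
        else:
            try:
                user_prompt = input("User: ")
            except KeyboardInterrupt:
                user_prompt = "/x"           # Ctrl+C also exits cleanly

        if user_prompt.startswith("/"):      # intercept commands, as chat.py does
            if user_prompt.strip() == "/x":
                break
            continue

        print(f"(model reply to {user_prompt!r} would go here)")

run("Why is the sky blue?")   # single-prompt mode: one reply, then exit
run()                         # interactive mode
```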

examples/multimodal.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -18,7 +18,7 @@
         model_dir = "/mnt/str/models/gemma3-4b-it/exl3/5.0bpw/"
     case "mistral3":
         prompt_format = "mistral"
-        model_dir = "/mnt/str/models/mistral-small-3.1-24b-instruct/exl3/8.0bpw_H8"
+        model_dir = "/mnt/str/models/mistral-small-3.1-24b-instruct-2503/exl3/4.0bpw/"
 
 images = [
     # Cat
```

exllamav3/__init__.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,5 +1,5 @@
-from .models.config import Config
-from .models.model import Model
+from .model.config import Config
+from .model.model import Model
 from .tokenizer import Tokenizer, MMEmbedding
 from .cache import Cache, CacheLayer_fp16, CacheLayer_quant
 from .generator import Generator, Job, AsyncGenerator, AsyncJob, Filter, FormatronFilter
```
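The rename from `exllamav3.models` to `exllamav3.model` only touches internal import paths; the names re-exported here are unchanged. For orientation, a minimal sketch of how these exports are typically wired together — the constructor and method names (`Config.from_directory`, `Model.from_config`, `model.load`, `Tokenizer.from_config`, `generator.generate`) follow the bundled examples and are assumptions, not something this diff establishes:

```python
# Hypothetical usage sketch of the top-level exports above; method names are assumptions.
from exllamav3 import Config, Model, Cache, Tokenizer, Generator

config = Config.from_directory("/path/to/exl3/model")   # assumed loader helper
model = Model.from_config(config)                       # concrete class picked via the architecture registry
cache = Cache(model, max_num_tokens = 8192)             # assumed cache constructor
model.load()
tokenizer = Tokenizer.from_config(config)

generator = Generator(model = model, cache = cache, tokenizer = tokenizer)
print(generator.generate(prompt = "Hello, my name is", max_new_tokens = 32))
```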

exllamav3/architecture/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -0,0 +1 @@
+from __future__ import annotations
```

exllamav3/architecture/arcee.py

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@ (new file)

```python
from typing_extensions import override
import torch
from ..modules import RMSNorm, Embedding, TransformerBlock, Attention, MLP, Linear
from ..model.config import Config, no_default
from ..model.model import Model
from ..util.rope import RopeStyle
from ..modules.attn import prepare_for_attn

class ArceeConfig(Config):
    arch_string = "ArceeForCausalLM"

    def __init__(
        self,
        directory: str,
        derived_model: dict | None = None,
        **kwargs,
    ):
        super().__init__(
            directory,
            derived_model if derived_model else {"text": ArceeModel},
            **kwargs
        )

        # Attention params
        self.head_dim = self.read_cfg(int, "head_dim", None)
        self.hidden_size = self.read_cfg(int, "hidden_size", no_default)
        self.num_q_heads = self.read_cfg(int, "num_attention_heads", no_default)
        self.num_kv_heads = self.read_cfg(int, "num_key_value_heads", self.num_q_heads)

        if not self.head_dim:
            self.head_dim = self.hidden_size // self.num_q_heads

        # MLP params
        self.assert_cfg(str, "hidden_act", "relu2", True)
        self.intermediate_size = self.read_cfg(int, "intermediate_size", no_default)

        # Norms
        self.rms_norm_eps = self.read_cfg(float, "rms_norm_eps", no_default)

        # Layers
        self.num_hidden_layers = self.read_cfg(int, "num_hidden_layers", no_default)
        self.tie_word_embeddings = self.read_cfg(bool, "tie_word_embeddings", False)

        # RoPE
        self.rope_settings = self.read_rope_settings_default(RopeStyle.NEOX)


class ArceeModel(Model):
    config_class = ArceeConfig

    def __init__(
        self,
        config: ArceeConfig,
        **kwargs
    ):
        super().__init__(config, **kwargs)

        self.modules += [
            Embedding(
                config = config,
                key = "model.embed_tokens",
                vocab_size = config.vocab_size,
                hidden_size = config.hidden_size,
            )
        ]

        self.first_block_idx = len(self.modules)

        self.modules += [
            TransformerBlock(
                config = config,
                key = f"model.layers.{idx}",
                attn_norm = RMSNorm(
                    config = config,
                    key = f"model.layers.{idx}.input_layernorm",
                    rms_norm_eps = config.rms_norm_eps,
                ),
                attn = Attention(
                    config = config,
                    key = f"model.layers.{idx}.self_attn",
                    layer_idx = idx,
                    hidden_size = config.hidden_size,
                    head_dim = config.head_dim,
                    num_q_heads = config.num_q_heads,
                    num_kv_heads = config.num_kv_heads,
                    rope_settings = config.rope_settings,
                    sm_scale = None,
                    key_q = "q_proj",
                    key_k = "k_proj",
                    key_v = "v_proj",
                    key_o = "o_proj",
                    qmap = "block.attn",
                ),
                mlp_norm = RMSNorm(
                    config = config,
                    key = f"model.layers.{idx}.post_attention_layernorm",
                    rms_norm_eps = config.rms_norm_eps,
                ),
                mlp = MLP(
                    config = config,
                    key = f"model.layers.{idx}.mlp",
                    hidden_size = config.hidden_size,
                    intermediate_size = config.intermediate_size,
                    key_up = "up_proj",
                    key_down = "down_proj",
                    qmap = "block.mlp",
                    activation_fn = "relu2",
                    out_dtype = torch.float,
                ),
            )
            for idx in range(config.num_hidden_layers)
        ]

        self.last_kv_module_idx = len(self.modules) - 1

        head_alt_key = None
        if config.tie_word_embeddings and not self.config.stc.has_tensor("lm_head"):
            head_alt_key = "model.embed_tokens"

        self.modules += [
            RMSNorm(
                config = config,
                key = "model.norm",
                rms_norm_eps = config.rms_norm_eps,
                out_dtype = torch.half,
            ),
            Linear(
                config = config,
                key = "lm_head",
                qbits_key = "head_bits",
                alt_key = head_alt_key,
                in_features = config.hidden_size,
                out_features = config.vocab_size,
                qmap = "block",
                caps = {"logits_output": True}
            )
        ]

        self.logit_layer_idx = len(self.modules) - 1

    @override
    def prepare_inputs(self, input_ids: torch.Tensor, params: dict) -> torch.Tensor:
        params["input_ids"] = input_ids
        input_ids = prepare_for_attn(input_ids, params)
        return input_ids

    @override
    def default_chat_prompt(self, prompt: str, system_prompt: str | None = None) -> str:
        p = "<|begin_of_text|>"
        if system_prompt:
            p += f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        p += f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
        return p
```
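One detail worth calling out in the new Arcee support: `ArceeConfig` asserts `hidden_act == "relu2"` and the blocks use a plain `MLP` (up/down projections only, no gate) with `activation_fn = "relu2"`, rather than the gated SiLU MLP of most Llama-style models. As a reference point, a small sketch of the squared-ReLU activation that "relu2" conventionally denotes — an illustration of the definition only, not ExLlamaV3's actual kernel:

```python
# Sketch of relu2 (squared ReLU), the activation ArceeConfig asserts and ArceeModel's MLP uses:
# relu2(x) = relu(x) ** 2. Definition only; not ExLlamaV3's fused implementation.
import torch
import torch.nn.functional as F

def relu2(x: torch.Tensor) -> torch.Tensor:
    r = F.relu(x)
    return r * r

x = torch.tensor([-2.0, -0.5, 0.0, 1.0, 3.0])
print(relu2(x))  # tensor([0., 0., 0., 1., 9.])
```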

exllamav3/models/architectures.py renamed to exllamav3/architecture/architectures.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -1,3 +1,4 @@
+from .arcee import ArceeModel
 from .cohere import CohereModel
 from .cohere2 import Cohere2Model
 from .decilm import DeciLMModel
@@ -8,6 +9,7 @@
 from .gemma2 import Gemma2Model
 from .gemma3 import Gemma3Model, Gemma3TextModel
 from .glm4 import Glm4Model
+from .glm4_moe import Glm4MoeModel
 from .llama import LlamaModel
 from .mimo import MiMoModel
 from .mistral import MistralModel
@@ -25,6 +27,7 @@
         "config_class": m.config_class,
         "model_class": m,
     } for m in [
+        ArceeModel,
         CohereModel,
         Cohere2Model,
         DeciLMModel,
@@ -36,6 +39,7 @@
         Gemma3Model,
         Gemma3TextModel,
         Glm4Model,
+        Glm4MoeModel,
         LlamaModel,
         MiMoModel,
         MistralModel,
@@ -50,4 +54,4 @@
 }
 
 def get_architectures():
-    return ARCHITECTURES
+    return ARCHITECTURES
```
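These registry entries are what make the new models loadable: `get_architectures()` exposes a dict of per-architecture entries carrying a `config_class` and a `model_class`, with Arcee registering under the `arch_string` `"ArceeForCausalLM"` defined in arcee.py above. A rough sketch of how such a registry is consumed when resolving a checkpoint directory — the keying by `arch_string` and the `config.json` lookup are illustrative assumptions, not code from this commit:

```python
# Illustrative sketch: resolve a checkpoint's architecture string against the registry
# built in architectures.py. The keying and helper below are assumptions, not commit code.
import json, os
from exllamav3.architecture.architectures import get_architectures

def resolve_architecture(model_dir: str):
    with open(os.path.join(model_dir, "config.json")) as f:
        cfg = json.load(f)
    arch_string = cfg["architectures"][0]          # e.g. "ArceeForCausalLM"
    entry = get_architectures().get(arch_string)   # assumes the dict is keyed by arch_string
    if entry is None:
        raise ValueError(f"Unsupported architecture: {arch_string}")
    return entry["config_class"], entry["model_class"]
```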

exllamav3/models/cohere.py renamed to exllamav3/architecture/cohere.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -1,9 +1,9 @@
 from __future__ import annotations
 from typing_extensions import override
 import torch
-from .config import Config, no_default
-from .model import Model
-from ..util.rope import RopeSettings, RopeStyle
+from ..model.config import Config, no_default
+from ..model.model import Model
+from ..util.rope import RopeStyle
 from ..modules import LayerNorm, Embedding, ParallelDecoderBlock, Attention, GatedMLP, Linear
 from ..modules.attn import prepare_for_attn
 
@@ -152,7 +152,6 @@ def __init__(
 
     @override
     def prepare_inputs(self, input_ids: torch.Tensor, params: dict) -> torch.Tensor:
-        params["input_ids"] = input_ids
         input_ids = prepare_for_attn(input_ids, params)
         return input_ids
 
```

exllamav3/models/cohere2.py renamed to exllamav3/architecture/cohere2.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -1,9 +1,9 @@
 from __future__ import annotations
 from typing_extensions import override
 import torch
-from .config import Config, no_default
-from .model import Model
-from ..util.rope import RopeSettings, RopeStyle
+from ..model.config import Config, no_default
+from ..model.model import Model
+from ..util.rope import RopeStyle
 from ..modules import LayerNorm, Embedding, ParallelDecoderBlock, Attention, GatedMLP, Linear
 from ..modules.attn import prepare_for_attn
 
@@ -151,7 +151,6 @@ def __init__(
 
     @override
     def prepare_inputs(self, input_ids: torch.Tensor, params: dict) -> torch.Tensor:
-        params["input_ids"] = input_ids
         input_ids = prepare_for_attn(input_ids, params)
         return input_ids
 
```

exllamav3/models/decilm.py renamed to exllamav3/architecture/decilm.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -1,9 +1,9 @@
 from __future__ import annotations
 from typing_extensions import override
 import torch
-from .config import Config, no_default
-from .model import Model
-from ..util.rope import RopeSettings, RopeStyle
+from ..model.config import Config, no_default
+from ..model.model import Model
+from ..util.rope import RopeStyle
 from ..modules import RMSNorm, Embedding, TransformerBlock, Attention, GatedMLP, Linear
 from ..modules.attn import prepare_for_attn
 
@@ -172,7 +172,6 @@ def __init__(
 
     @override
     def prepare_inputs(self, input_ids: torch.Tensor, params: dict) -> torch.Tensor:
-        params["input_ids"] = input_ids
         input_ids = prepare_for_attn(input_ids, params)
         return input_ids
 
```
