ModelSizerv3.py
import argparse
from transformers import AutoConfig
# ---------------------------
# Args
# ---------------------------
parser = argparse.ArgumentParser(description="Model Sizer")
parser.add_argument("--model_repo", type=str, default="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8")
parser.add_argument("--kv_dtype", type=str, choices=["fp8", "fp16", "fp32"], default="fp8")
parser.add_argument("--context_window", type=int)
parser.add_argument("--trust_remote_code", action="store_true")
args = parser.parse_args()
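# Example invocation (repo name is just an illustration; any HF causal-LM repo
# with a standard config should work):
#   python ModelSizerv3.py --model_repo meta-llama/Llama-3.1-8B-Instruct \
#       --kv_dtype fp16 --context_window 8192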
# ---------------------------
# Config loader with dual-scope lookup
# ---------------------------
raw_cfg = AutoConfig.from_pretrained(args.model_repo, trust_remote_code=args.trust_remote_code)
text_cfg = getattr(raw_cfg, "text_config", None)
def get_attr(name, default=None):
    if text_cfg is not None and hasattr(text_cfg, name):
        return getattr(text_cfg, name)
    return getattr(raw_cfg, name, default)
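# Multimodal repos (e.g. Llama-4) nest the language-model fields such as
# hidden_size and num_hidden_layers under text_config, while text-only repos
# keep them at the top level; get_attr checks text_config first, then the root.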
# ---------------------------
# Core getters
# ---------------------------
def hidden_size(): return int(get_attr("hidden_size", 0))
def num_layers(): return int(get_attr("num_hidden_layers", 0))
def num_heads(): return int(get_attr("num_attention_heads", 0))
def num_kv_heads():
    v = (get_attr("num_key_value_heads", None)
         or get_attr("num_kv_heads", None)
         or get_attr("n_kv_heads", None)
         or get_attr("n_head_kv", None))
    if v is not None:
        return int(v)
    return 1 if bool(get_attr("multi_query", False)) else num_heads()
def head_dim():
    v = (get_attr("head_dim", None)
         or get_attr("attention_head_size", None)
         or get_attr("dim_head", None)
         or get_attr("headdim", None))
    return int(v) if v is not None else (hidden_size() // num_heads() if num_heads() else 0)
def vocab_size(): return int(get_attr("vocab_size", 0))
def tie_word_embeddings(): return bool(get_attr("tie_word_embeddings", True))
def intermediate_size_local(): return int(get_attr("intermediate_size", 0))
def intermediate_size_shared(): return int(get_attr("intermediate_size_mlp", 0))  # shared expert (Llama-4)
def num_local_experts(): return int(get_attr("num_local_experts", 0))
def num_experts_per_token(): return int(get_attr("num_experts_per_tok", get_attr("experts_per_token", 1)))
def layer_types():
    lt = get_attr("layer_types", None)
    return list(lt) if isinstance(lt, list) else None
def moe_layer_indices():
    idx = get_attr("moe_layers", None)
    if isinstance(idx, list) and all(isinstance(i, int) for i in idx):
        return set(idx)
    # fallback: interleave every k layers if a step is provided
    step = int(get_attr("interleave_moe_layer_step", 0) or 0)
    if step > 0:
        return set(range(0, num_layers(), step))
    return set()
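# e.g. interleave_moe_layer_step=2 over 48 layers marks every other layer as MoE
# (24 of 48); an explicit moe_layers list in the config always takes precedence.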
def context_window():
    if args.context_window is not None:
        return int(args.context_window)
    return int(get_attr("max_position_embeddings",
                        get_attr("context_window",
                                 get_attr("n_ctx", 0))))
# ---------------------------
# MLP style (Llama-family: gated/SwiGLU)
# ---------------------------
def is_gated_mlp():
    act = (get_attr("hidden_act", "") or get_attr("activation_function", "") or "").lower()
    if "glu" in act:
        return True
    model_type = (get_attr("model_type", "") or "").lower()
    return model_type.startswith("llama")
def mlp_params(d_model, d_ff):
    return (3 if is_gated_mlp() else 2) * d_model * d_ff
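# Worked example (illustrative, roughly Llama-3-8B-sized): d_model=4096,
# d_ff=14336 with a gated (SwiGLU) MLP -> gate + up + down projections
# = 3 * 4096 * 14336 ≈ 176M parameters per MLP block.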
# ---------------------------
# Attention parameters per layer
# Q: d x (H*dk), K: d x (Hkv*dk), V: d x (Hkv*dk), O: d x (H*dk)
# ---------------------------
def attention_params_per_layer():
    d = hidden_size()
    H = num_heads()
    Hkv = num_kv_heads()
    dk = head_dim()
    if not (d and H and Hkv and dk):
        return 0
    d_attn = H * dk
    d_kv = Hkv * dk
    q_and_o = 2 * d * d_attn
    k_and_v = 2 * d * d_kv
    return q_and_o + k_and_v
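# Worked example (illustrative GQA shape, roughly Llama-3-8B-sized): d=4096,
# H=32, Hkv=8, dk=128 -> Q and O projections = 2 * 4096 * 4096 ≈ 33.6M,
# K and V projections = 2 * 4096 * 1024 ≈ 8.4M, so ≈ 41.9M per layer.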
# ---------------------------
# Embeddings (untied doubles the cost)
# ---------------------------
def embedding_params_total():
    V = vocab_size()
    d = hidden_size()
    if not (V and d):
        return 0
    base = V * d
    return base if tie_word_embeddings() else 2 * base
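# Worked example (illustrative): V=128256, d=4096 -> 128256 * 4096 ≈ 0.53B
# embedding parameters when tied, ≈ 1.05B when the LM head is untied.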
# ---------------------------
# Parameter accounting
# - Total parameters: store/load footprint (all experts)
# - Active parameters: single-token forward (uses num_experts_per_token)
# ---------------------------
def total_and_active_parameter_counts():
    d = hidden_size()
    d_ff_local = intermediate_size_local()
    d_ff_shared = intermediate_size_shared()
    E = num_local_experts()
    ept = num_experts_per_token()
    L = num_layers()
    moe_layers = moe_layer_indices()
    attn_per_layer = attention_params_per_layer()
    dense_mlp_per_layer = mlp_params(d, d_ff_local) if d_ff_local else 0
    shared_mlp_per_layer = mlp_params(d, d_ff_shared) if d_ff_shared else 0
    total_params = embedding_params_total()
    active_params = embedding_params_total()
    for layer_idx in range(L):
        total_params += attn_per_layer
        active_params += attn_per_layer
        if layer_idx in moe_layers:
            # shared expert once per MoE layer
            total_params += shared_mlp_per_layer
            active_params += shared_mlp_per_layer
            # local experts: all for total, routed subset for active
            total_params += E * dense_mlp_per_layer
            active_params += ept * dense_mlp_per_layer
            # router is tiny; include once per MoE layer
            total_params += d * E
            active_params += d * E
        else:
            # dense non-MoE layer
            total_params += dense_mlp_per_layer
            active_params += dense_mlp_per_layer
    return int(total_params), int(active_params)
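# Illustrative accounting for one hypothetical MoE layer: with E=16 local experts
# and ept=1 routed expert per token, the layer adds shared + 16 expert MLPs to
# total_params but only shared + 1 expert MLP to active_params.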
# ---------------------------
# Bytes-per-parameter (weights)
# Prefer explicit quant config; fall back to dtype.
# ---------------------------
def bytes_per_parameter():
    qcfg = getattr(raw_cfg, "quantization_config", {}) or {}
    # Detect "float-quantized" 8-bit / 4-bit schemes
    try:
        fmt = (qcfg.get("format") or "").lower()
        grp = (qcfg.get("config_groups") or {}).get("group_0", {})
        w = (grp.get("weights") or {})
        w_bits = w.get("num_bits", None)
        w_type = (w.get("type") or "").lower()  # "float" or "int"
        if fmt == "float-quantized" and w_type == "float" and w_bits == 8:
            return 1.0  # FP8-like storage
        if fmt == "float-quantized" and w_type == "float" and w_bits == 4:
            return 0.5  # FP4-like storage
    except Exception:
        pass
    # Fall back to the config dtype; cast to str so torch.dtype objects
    # (e.g. torch.bfloat16) map cleanly alongside plain strings.
    dtype = str(get_attr("torch_dtype", None) or get_attr("dtype", "") or "").lower()
    mapping = {
        "fp32": 4.0, "float32": 4.0, "torch.float32": 4.0,
        "fp16": 2.0, "float16": 2.0, "torch.float16": 2.0,
        "bf16": 2.0, "bfloat16": 2.0, "torch.bfloat16": 2.0,
        "fp8": 1.0, "e4m3": 1.0, "e5m2": 1.0,
    }
    return mapping.get(dtype, 2.0)  # default to BF16
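# Shape of the quantization_config this function looks for (only the fields read
# above, not a complete schema):
# {
#   "format": "float-quantized",
#   "config_groups": {"group_0": {"weights": {"num_bits": 8, "type": "float"}}}
# }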
# ---------------------------
# KV cache size
# ---------------------------
def bytes_per_kv_element():
    return 1 if args.kv_dtype == "fp8" else 2 if args.kv_dtype == "fp16" else 4
def tokens_kept_in_layer(layer_type: str, tokens: int) -> int:
    # Treat "sliding"/"local" as windowed if a window is provided.
    # "chunked_attention" often chunks compute but keeps full KV; do not reduce.
    lt = (layer_type or "").lower()
    if "sliding" in lt or "local" in lt:
        win = int(get_attr("sliding_window", 0) or 0)
        return min(tokens, win) if win > 0 else tokens
    return tokens
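# e.g. sliding_window=4096 with a 32768-token context -> a "sliding_attention"
# layer caches only 4096 tokens, while a "full_attention" layer caches all 32768.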
def kv_cache_size_bytes():
    tokens = context_window()
    L = num_layers()
    Hkv = num_kv_heads()
    dk = head_dim()
    bpe = bytes_per_kv_element()
    if not (tokens and L and Hkv and dk and bpe):
        return 0
    per_token_per_layer = 2 * Hkv * dk * bpe  # K and V
    lt = layer_types()
    if isinstance(lt, list) and len(lt) == L:
        total_tokens_across_layers = 0
        for t in lt:
            total_tokens_across_layers += tokens_kept_in_layer(str(t), tokens)
        return per_token_per_layer * total_tokens_across_layers
    return per_token_per_layer * tokens * L
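# Worked example (illustrative): Hkv=8, dk=128, fp16 KV (2 bytes) ->
# 2 * 8 * 128 * 2 = 4096 bytes per token per layer; with 32 layers and an
# 8192-token context, 4096 * 32 * 8192 bytes = exactly 1 GiB of KV cache.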
# ---------------------------
# Units
# ---------------------------
def bytes_to_gib(n): return n / (1024 ** 3)
# ---------------------------
# Compute
# ---------------------------
total_params, active_params = total_and_active_parameter_counts()
bpp = bytes_per_parameter()
model_size_total_gib = bytes_to_gib(total_params * bpp)
model_size_active_gib = bytes_to_gib(active_params * bpp) # rarely used; total is what you must store
kv_bytes = kv_cache_size_bytes()
# ---------------------------
# Print
# ---------------------------
print(
    f"Model: {args.model_repo}\n"
    f"Context Window: {context_window()}\n"
    f"Layers: {num_layers()}\n"
    f"Attention Heads: {num_heads()}\n"
    f"KV Heads: {num_kv_heads()}\n"
    f"Head Dimension: {head_dim()}\n"
    f"Hidden Size: {hidden_size()}\n"
    f"Local Experts: {num_local_experts()}\n"
    f"Experts per Token: {num_experts_per_token()}\n"
    f"Tied Embeddings: {tie_word_embeddings()}\n"
    f"Bytes/Weight Param: {bpp}\n"
    f"Total Parameters (store/load): {total_params/1e9:.2f} Billion\n"
    f"Active Parameters (per token): {active_params/1e9:.2f} Billion\n"
    f"Estimated Model Size (Total): {model_size_total_gib:.2f} GiB\n"
    f"Estimated Model Size (Active): {model_size_active_gib:.2f} GiB\n"
    f"Estimated KV Cache Size: {bytes_to_gib(kv_bytes):.2f} GiB"
)