Skip to content

Commit 83ca891

Browse files
Support scaled fp8 t5xxl model.
1 parent f9f9faf commit 83ca891

File tree

6 files changed

+63
-30
lines changed

6 files changed

+63
-30
lines changed

comfy/ops.py

+11-2
Original file line number | Diff line number | Diff line change
@@ -290,12 +290,21 @@ def forward_comfy_cast_weights(self, input):
290290
weight, bias = cast_bias_weight(self, input)
291291
return torch.nn.functional.linear(input, weight, bias)
292292

293-
def scaled_fp8_ops(fp8_matrix_mult=False):
293+
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
294294
class scaled_fp8_op(manual_cast):
295295
class Linear(manual_cast.Linear):
296+
def __init__(self, *args, **kwargs):
297+
if override_dtype is not None:
298+
kwargs['dtype'] = override_dtype
299+
super().__init__(*args, **kwargs)
300+
296301
def reset_parameters(self):
297302
if not hasattr(self, 'scale_weight'):
298303
self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
304+
305+
if not scale_input:
306+
self.scale_input = None
307+
299308
if not hasattr(self, 'scale_input'):
300309
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
301310
return None
@@ -328,7 +337,7 @@ def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
328337
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=False):
329338
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
330339
if scaled_fp8:
331-
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute)
340+
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True)
332341

333342
if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
334343
return fp8_ops

comfy/sd.py

+9-10
Original file line number | Diff line number | Diff line change
@@ -432,16 +432,15 @@ def detect_te_model(sd):
432432
return None
433433

434434

435-
def t5xxl_weight_dtype(clip_data):
435+
def t5xxl_detect(clip_data):
436436
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
437437

438438
dtype_t5 = None
439439
for sd in clip_data:
440-
weight = sd.get(weight_name, None)
441-
if weight is not None:
442-
dtype_t5 = weight.dtype
443-
break
444-
return dtype_t5
440+
if weight_name in sd:
441+
return comfy.text_encoders.sd3_clip.t5_xxl_detect(sd)
442+
443+
return {}
445444

446445

447446
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -475,7 +474,7 @@ class EmptyClass:
475474
clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
476475
clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
477476
elif te_model == TEModel.T5_XXL:
478-
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=t5xxl_weight_dtype(clip_data))
477+
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
479478
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
480479
elif te_model == TEModel.T5_XL:
481480
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
@@ -493,19 +492,19 @@ class EmptyClass:
493492
elif len(clip_data) == 2:
494493
if clip_type == CLIPType.SD3:
495494
te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
496-
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models, dtype_t5=t5xxl_weight_dtype(clip_data))
495+
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models, **t5xxl_detect(clip_data))
497496
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
498497
elif clip_type == CLIPType.HUNYUAN_DIT:
499498
clip_target.clip = comfy.text_encoders.hydit.HyditModel
500499
clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
501500
elif clip_type == CLIPType.FLUX:
502-
clip_target.clip = comfy.text_encoders.flux.flux_clip(dtype_t5=t5xxl_weight_dtype(clip_data))
501+
clip_target.clip = comfy.text_encoders.flux.flux_clip(**t5xxl_detect(clip_data))
503502
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
504503
else:
505504
clip_target.clip = sdxl_clip.SDXLClipModel
506505
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
507506
elif len(clip_data) == 3:
508-
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(dtype_t5=t5xxl_weight_dtype(clip_data))
507+
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data))
509508
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
510509

511510
parameters = 0

comfy/sd1_clip.py

+10-1
Original file line number | Diff line number | Diff line change
@@ -94,11 +94,20 @@ def __init__(self, device="cpu", max_length=77,
9494
config = json.load(f)
9595

9696
operations = model_options.get("custom_operations", None)
97+
scaled_fp8 = None
98+
9799
if operations is None:
98-
operations = comfy.ops.manual_cast
100+
scaled_fp8 = model_options.get("scaled_fp8", None)
101+
if scaled_fp8 is not None:
102+
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
103+
else:
104+
operations = comfy.ops.manual_cast
99105

100106
self.operations = operations
101107
self.transformer = model_class(config, dtype, device, self.operations)
108+
if scaled_fp8 is not None:
109+
self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))
110+
102111
self.num_layers = self.transformer.num_layers
103112

104113
self.max_length = max_length

comfy/supported_models.py

+5-9
Original file line number | Diff line number | Diff line change
@@ -529,12 +529,11 @@ def clip_target(self, state_dict={}):
529529
clip_l = True
530530
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
531531
clip_g = True
532-
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
533-
if t5_key in state_dict:
532+
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
533+
if "dtype_t5" in t5_detect:
534534
t5 = True
535-
dtype_t5 = state_dict[t5_key].dtype
536535

537-
return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
536+
return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, **t5_detect))
538537

539538
class StableAudio(supported_models_base.BASE):
540539
unet_config = {
@@ -653,11 +652,8 @@ def get_model(self, state_dict, prefix="", device=None):
653652

654653
def clip_target(self, state_dict={}):
655654
pref = self.text_encoder_key_prefix[0]
656-
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
657-
dtype_t5 = None
658-
if t5_key in state_dict:
659-
dtype_t5 = state_dict[t5_key].dtype
660-
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))
655+
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
656+
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
661657

662658
class FluxSchnell(Flux):
663659
unet_config = {

comfy/text_encoders/flux.py

+6-7
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,11 @@
11
from comfy import sd1_clip
22
import comfy.text_encoders.t5
3+
import comfy.text_encoders.sd3_clip
34
import comfy.model_management
45
from transformers import T5TokenizerFast
56
import torch
67
import os
78

8-
class T5XXLModel(sd1_clip.SDClipModel):
9-
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
10-
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
11-
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
12-
139
class T5XXLTokenizer(sd1_clip.SDTokenizer):
1410
def __init__(self, embedding_directory=None, tokenizer_data={}):
1511
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
@@ -41,7 +37,7 @@ def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
4137
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
4238
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
4339
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
44-
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
40+
self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
4541
self.dtypes = set([dtype, dtype_t5])
4642

4743
def set_clip_options(self, options):
@@ -66,8 +62,11 @@ def load_sd(self, sd):
6662
else:
6763
return self.t5xxl.load_sd(sd)
6864

69-
def flux_clip(dtype_t5=None):
65+
def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
7066
class FluxClipModel_(FluxClipModel):
7167
def __init__(self, device="cpu", dtype=None, model_options={}):
68+
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
69+
model_options = model_options.copy()
70+
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
7271
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
7372
return FluxClipModel_

comfy/text_encoders/sd3_clip.py

+22-1
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,26 @@
1010
class T5XXLModel(sd1_clip.SDClipModel):
1111
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
1212
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
13+
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
14+
if t5xxl_scaled_fp8 is not None:
15+
model_options = model_options.copy()
16+
model_options["scaled_fp8"] = t5xxl_scaled_fp8
17+
1318
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
1419

20+
21+
def t5_xxl_detect(state_dict, prefix=""):
22+
out = {}
23+
t5_key = "{}encoder.final_layer_norm.weight".format(prefix)
24+
if t5_key in state_dict:
25+
out["dtype_t5"] = state_dict[t5_key].dtype
26+
27+
scaled_fp8_key = "{}scaled_fp8".format(prefix)
28+
if scaled_fp8_key in state_dict:
29+
out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
30+
31+
return out
32+
1533
class T5XXLTokenizer(sd1_clip.SDTokenizer):
1634
def __init__(self, embedding_directory=None, tokenizer_data={}):
1735
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
@@ -139,8 +157,11 @@ def load_sd(self, sd):
139157
else:
140158
return self.t5xxl.load_sd(sd)
141159

142-
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False):
160+
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
143161
class SD3ClipModel_(SD3ClipModel):
144162
def __init__(self, device="cpu", dtype=None, model_options={}):
163+
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
164+
model_options = model_options.copy()
165+
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
145166
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
146167
return SD3ClipModel_

0 commit comments

Comments (0)