Commit e813abb

Long CLIP L support for SDXL, SD3 and Flux.
Use the *CLIPLoader nodes.
1 parent 5e68a4c commit e813abb

6 files changed: +36 -17 lines changed

comfy/sd.py  (+5 -7)

@@ -445,12 +445,8 @@ class EmptyClass:
             clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
         else:
             w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
-            if w is not None and w.shape[0] == 248:
-                clip_target.clip = comfy.text_encoders.long_clipl.LongClipModel
-                clip_target.tokenizer = comfy.text_encoders.long_clipl.LongClipTokenizer
-            else:
-                clip_target.clip = sd1_clip.SD1ClipModel
-                clip_target.tokenizer = sd1_clip.SD1Tokenizer
+            clip_target.clip = sd1_clip.SD1ClipModel
+            clip_target.tokenizer = sd1_clip.SD1Tokenizer
     elif len(clip_data) == 2:
         if clip_type == CLIPType.SD3:
             clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)

@@ -475,10 +471,12 @@ class EmptyClass:
         clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer

     parameters = 0
+    tokenizer_data = {}
     for c in clip_data:
         parameters += comfy.utils.calculate_parameters(c)
+        tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)

-    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, model_options=model_options)
+    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
     for c in clip_data:
         m, u = clip.load_sd(c)
         if len(m) > 0:
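Note: the loader change threads per-checkpoint overrides through two dicts, tokenizer_data (read by the tokenizer constructors) and model_options (read by the model constructors). A minimal standalone sketch of that accumulation pattern; detect_long_clip is a made-up stand-in for comfy.text_encoders.long_clipl.model_options_long_clip, and the override values are placeholder strings rather than the real classes.

# Minimal standalone sketch (not ComfyUI's actual loader).
import torch

def detect_long_clip(sd, tokenizer_data, model_options):
    # Flag the checkpoint as Long CLIP L when its position embedding has 248 rows.
    w = sd.get("text_model.embeddings.position_embedding.weight", None)
    if w is not None and w.shape[0] == 248:
        tokenizer_data = {**tokenizer_data, "clip_l_tokenizer_class": "LongClipTokenizer_"}
        model_options = {**model_options, "clip_l_class": "LongClipModel_"}
    return tokenizer_data, model_options

clip_data = [
    {"text_model.embeddings.position_embedding.weight": torch.zeros(248, 768)},   # Long CLIP L
    {"encoder.block.0.layer.0.SelfAttention.q.weight": torch.zeros(4096, 4096)},  # e.g. a T5 checkpoint
]

tokenizer_data, model_options = {}, {}
for c in clip_data:
    tokenizer_data, model_options = detect_long_clip(c, tokenizer_data, model_options)

print(tokenizer_data)  # {'clip_l_tokenizer_class': 'LongClipTokenizer_'}
print(model_options)   # {'clip_l_class': 'LongClipModel_'}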

comfy/sd1_clip.py  (+2)

@@ -542,6 +542,7 @@ class SD1Tokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
         self.clip_name = clip_name
         self.clip = "clip_{}".format(self.clip_name)
+        tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
         setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))

     def tokenize_with_weights(self, text:str, return_word_ids=False):

@@ -570,6 +571,7 @@ def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", cl
         self.clip_name = clip_name
         self.clip = "clip_{}".format(self.clip_name)

+        clip_model = model_options.get("{}_class".format(self.clip), clip_model)
         setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))

         self.dtypes = set()
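Note: both hooks use the same lookup-with-fallback pattern, keyed by the clip name ("clip_l_tokenizer_class" / "clip_l_class" for clip_l). A toy sketch of how such an entry swaps the class without subclassing the wrapper; DefaultTokenizer, LongTokenizer and Holder are illustrative names, not ComfyUI classes.

# Toy illustration of the dict-based class override added above.
class DefaultTokenizer:
    max_length = 77

class LongTokenizer:
    max_length = 248

class Holder:
    def __init__(self, clip_name="l", tokenizer=DefaultTokenizer, tokenizer_data={}):
        self.clip = "clip_{}".format(clip_name)
        # Same lookup as SD1Tokenizer: a "clip_l_tokenizer_class" entry replaces the default class.
        tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
        setattr(self, self.clip, tokenizer())

print(Holder().clip_l.max_length)  # 77
print(Holder(tokenizer_data={"clip_l_tokenizer_class": LongTokenizer}).clip_l.max_length)  # 248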

comfy/sdxl_clip.py  (+6 -3)

@@ -22,7 +22,8 @@ def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data

 class SDXLTokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
-        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
+        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
+        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
         self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)

     def tokenize_with_weights(self, text:str, return_word_ids=False):

@@ -40,7 +41,8 @@ def state_dict(self):
 class SDXLClipModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None, model_options={}):
         super().__init__()
-        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
+        clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
+        self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
         self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
         self.dtypes = set([dtype])

@@ -57,7 +59,8 @@ def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs_l = token_weight_pairs["l"]
         g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
         l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
-        return torch.cat([l_out, g_out], dim=-1), g_pooled
+        cut_to = min(l_out.shape[1], g_out.shape[1])
+        return torch.cat([l_out[:,:cut_to], g_out[:,:cut_to]], dim=-1), g_pooled

     def load_sd(self, sd):
         if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:

comfy/text_encoders/flux.py  (+4 -2)

@@ -18,7 +18,8 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):

 class FluxTokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
-        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
+        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
+        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
         self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)

     def tokenize_with_weights(self, text:str, return_word_ids=False):

@@ -38,7 +39,8 @@ class FluxClipModel(torch.nn.Module):
     def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
         super().__init__()
         dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
-        self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
+        clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
+        self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
         self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
         self.dtypes = set([dtype, dtype_t5])

comfy/text_encoders/long_clipl.py  (+13 -2)

@@ -6,9 +6,9 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):
         super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

 class LongClipModel_(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", dtype=None, model_options={}):
+    def __init__(self, *args, **kwargs):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
-        super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options)
+        super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)

 class LongClipTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):

@@ -17,3 +17,14 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):
 class LongClipModel(sd1_clip.SD1ClipModel):
     def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
         super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
+
+def model_options_long_clip(sd, tokenizer_data, model_options):
+    w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
+    if w is None:
+        w = sd.get("text_model.embeddings.position_embedding.weight", None)
+    if w is not None and w.shape[0] == 248:
+        tokenizer_data = tokenizer_data.copy()
+        model_options = model_options.copy()
+        tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
+        model_options["clip_l_class"] = LongClipModel_
+    return tokenizer_data, model_options
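Note: detection is purely shape based; a standard CLIP-L position embedding has 77 rows, while Long CLIP L has 248. A hedged usage sketch of model_options_long_clip (assumes a ComfyUI checkout is importable; the state dict here is a fake containing only the key that gets checked):

# Usage sketch for model_options_long_clip.
import torch
from comfy.text_encoders.long_clipl import model_options_long_clip

sd = {"clip_l.text_model.embeddings.position_embedding.weight": torch.zeros(248, 768)}
tokenizer_data, model_options = model_options_long_clip(sd, {}, {})

# Both dicts now carry the Long CLIP overrides read back by the tokenizer/model constructors.
print("clip_l_tokenizer_class" in tokenizer_data)  # True
print("clip_l_class" in model_options)             # True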

comfy/text_encoders/sd3_clip.py  (+6 -3)

@@ -20,7 +20,8 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):

 class SD3Tokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
-        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
+        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
+        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
         self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
         self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)

@@ -42,7 +43,8 @@ def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu
         super().__init__()
         self.dtypes = set()
         if clip_l:
-            self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
+            clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
+            self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
             self.dtypes.add(dtype)
         else:
             self.clip_l = None

@@ -95,7 +97,8 @@ def encode_token_weights(self, token_weight_pairs):
         if self.clip_g is not None:
             g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
             if lg_out is not None:
-                lg_out = torch.cat([lg_out, g_out], dim=-1)
+                cut_to = min(lg_out.shape[1], g_out.shape[1])
+                lg_out = torch.cat([lg_out[:,:cut_to], g_out[:,:cut_to]], dim=-1)
             else:
                 lg_out = torch.nn.functional.pad(g_out, (768, 0))
         else:
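Note: the joint SD3 embedding keeps its 2048-channel layout, with CLIP-L in the first 768 channels and CLIP-G in the remaining 1280 (pad(g_out, (768, 0)) fills the L slot when no CLIP-L is loaded); the new cut_to only aligns sequence lengths before the concat, mirroring the SDXL change. A shape sketch with random tensors (not real encoder outputs):

# Shape sketch of the SD3 l/g merge.
import torch

lg_out = torch.randn(1, 248, 768)  # Long CLIP L output, first 768 channels of the joint embedding
g_out = torch.randn(1, 77, 1280)   # CLIP G output

cut_to = min(lg_out.shape[1], g_out.shape[1])
lg_out = torch.cat([lg_out[:, :cut_to], g_out[:, :cut_to]], dim=-1)
print(lg_out.shape)  # torch.Size([1, 77, 2048])

# With no CLIP L loaded, the G output is left-padded into the same 2048-channel layout.
g_only = torch.nn.functional.pad(torch.randn(1, 77, 1280), (768, 0))
print(g_only.shape)  # torch.Size([1, 77, 2048])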
