@@ -22,7 +22,8 @@ def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data
22
22
23
23
class SDXLTokenizer :
24
24
def __init__(self, embedding_directory=None, tokenizer_data=None):
    """Build the paired tokenizers used by this SDXL-style tokenizer.

    Args:
        embedding_directory: Optional path forwarded to both sub-tokenizers
            (presumably where textual-inversion embeddings live — confirm
            against SDTokenizer).
        tokenizer_data: Optional dict of overrides. If it contains
            "clip_l_tokenizer_class", that class is used in place of
            sd1_clip.SDTokenizer for the clip_l tokenizer.
    """
    # NOTE: default changed from a shared mutable `{}` to a None sentinel;
    # behavior for all callers is unchanged (the dict is only read here).
    if tokenizer_data is None:
        tokenizer_data = {}
    # Allow callers to swap in a custom tokenizer class for clip_l.
    clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
    self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
    self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
27
28
28
29
def tokenize_with_weights (self , text :str , return_word_ids = False ):
@@ -40,7 +41,8 @@ def state_dict(self):
40
41
class SDXLClipModel (torch .nn .Module ):
41
42
def __init__(self, device="cpu", dtype=None, model_options=None):
    """Build the paired text-encoder modules for this SDXL-style CLIP model.

    Args:
        device: Device string passed to both sub-models (default "cpu").
        dtype: Optional dtype passed to both sub-models; also recorded in
            ``self.dtypes``.
        model_options: Optional dict of overrides, forwarded to both
            sub-models. If it contains "clip_l_class", that class is used
            in place of sd1_clip.SDClipModel for the clip_l encoder.
    """
    super().__init__()
    # NOTE: default changed from a shared mutable `{}` to a None sentinel;
    # behavior for all callers is unchanged (the dict is only read here
    # and forwarded).
    if model_options is None:
        model_options = {}
    # Allow callers to swap in a custom model class for clip_l.
    clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
    self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype,
                               layer_norm_hidden_state=False, model_options=model_options)
    self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
    # Set literal instead of set([dtype]); tracks the dtype(s) in use.
    self.dtypes = {dtype}
46
48
@@ -57,7 +59,8 @@ def encode_token_weights(self, token_weight_pairs):
57
59
token_weight_pairs_l = token_weight_pairs ["l" ]
58
60
g_out , g_pooled = self .clip_g .encode_token_weights (token_weight_pairs_g )
59
61
l_out , l_pooled = self .clip_l .encode_token_weights (token_weight_pairs_l )
60
- return torch .cat ([l_out , g_out ], dim = - 1 ), g_pooled
62
+ cut_to = min (l_out .shape [1 ], g_out .shape [1 ])
63
+ return torch .cat ([l_out [:,:cut_to ], g_out [:,:cut_to ]], dim = - 1 ), g_pooled
61
64
62
65
def load_sd (self , sd ):
63
66
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd :
0 commit comments