@@ -114,7 +114,7 @@ def forward(self, vec: Tensor) -> tuple:
 
 
 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
         super().__init__()
 
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -141,6 +141,7 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias:
             nn.GELU(approximate="tanh"),
             operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
         )
+        self.flipped_img_txt = flipped_img_txt
 
     def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
         img_mod1, img_mod2 = self.img_mod(vec)
@@ -160,13 +161,22 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
 
-        # run actual attention
-        attn = attention(torch.cat((txt_q, img_q), dim=2),
-                         torch.cat((txt_k, img_k), dim=2),
-                         torch.cat((txt_v, img_v), dim=2),
-                         pe=pe, mask=attn_mask)
-
-        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
+        if self.flipped_img_txt:
+            # run actual attention
+            attn = attention(torch.cat((img_q, txt_q), dim=2),
+                             torch.cat((img_k, txt_k), dim=2),
+                             torch.cat((img_v, txt_v), dim=2),
+                             pe=pe, mask=attn_mask)
+
+            img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
+        else:
+            # run actual attention
+            attn = attention(torch.cat((txt_q, img_q), dim=2),
+                             torch.cat((txt_k, img_k), dim=2),
+                             torch.cat((txt_v, img_v), dim=2),
+                             pe=pe, mask=attn_mask)
+
+            txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
 
         # calculate the img bloks
         img = img + img_mod1.gate * self.img_attn.proj(img_attn)
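The new `flipped_img_txt` flag only changes the order in which the image and text token streams are concatenated for the joint attention call, and the split of the result is adjusted to match. A minimal sketch of that ordering, using hypothetical shapes (`B`, `H`, `L_IMG`, `L_TXT`, `D` are placeholders, not values from the diff) and plain tensor ops instead of the repo's `attention` helper, which also merges heads before the split in the real code:

```python
import torch

# Hypothetical sizes: batch, heads, image tokens, text tokens, head dim.
B, H, L_IMG, L_TXT, D = 1, 2, 4, 3, 8
img_q = torch.randn(B, H, L_IMG, D)
txt_q = torch.randn(B, H, L_TXT, D)

# flipped_img_txt=True: image tokens come first in the joint sequence,
# so the output is split at img length, image part first.
joint_flipped = torch.cat((img_q, txt_q), dim=2)              # (B, H, L_IMG + L_TXT, D)
img_part, txt_part = joint_flipped[:, :, :L_IMG], joint_flipped[:, :, L_IMG:]

# flipped_img_txt=False (original behavior): text tokens come first,
# so the output is split at txt length, text part first.
joint_default = torch.cat((txt_q, img_q), dim=2)              # (B, H, L_TXT + L_IMG, D)
txt_part2, img_part2 = joint_default[:, :, :L_TXT], joint_default[:, :, L_TXT:]

# Both orderings recover the same per-stream tokens; only the layout differs.
assert torch.equal(img_part, img_q) and torch.equal(txt_part2, txt_q)
```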