    LayerNorm,
    MultiheadAttention,
    SamePad,
+    Smartpool,
    TransposeLast,
)
from fairseq.modules.transformer_sentence_encoder import init_bert_params
@@ -330,10 +331,7 @@ def __init__(self, args):
            conv_bias=args.conv_bias,
        )

-        self.smartpooling = args.smartpooling
-        self.smartpooling_search_perc = args.smartpooling_search_perc
-        self.smartpooling_factor = args.smartpooling_factor
-        self.smartpooling_filters = torch.tensor([[[[-1, 1], [1, -1]]]]).float()
+        self.smartpool = Smartpool(args.smartpooling_factor, args.smartpooling_search_perc) if args.smartpooling else None
        self.post_extract_proj = (
            nn.Linear(self.embed, args.encoder_embed_dim)
            if self.embed != args.encoder_embed_dim and not args.quantize_input
@@ -425,11 +423,6 @@ def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        return state_dict

-    def to(self, *args, **kwargs):
-        self = super().to(*args, **kwargs)
-        self.smartpooling_filters = self.smartpooling_filters.to(*args, **kwargs)
-        return self
-
    @classmethod
    def build_model(cls, args, task=None):
        """Build a new model instance."""
@@ -552,100 +545,6 @@ def compute_preds(self, x, y, negatives):

        return logits

-    def smartpool(self, features, padding_mask=None):
-        B, T, C = features.size()
-
-        padding_per_batch = (padding_mask > 0).sum(1)
-        total_T = padding_mask.numel() - padding_per_batch.sum()
-        features_together = torch.cat([features[i, :T - x] for i, x in enumerate(padding_per_batch)]).unsqueeze(0)
-
-        features_tmp = F.pad(features, (0, 0, 1, 0), value=features_together.mean().item())
-        features_tmp = features_tmp.view(1, B * (T + 1), C)
-
-        # We have to remove 1 front padding and X_i back paddings from each batch. X_i can be arbitrary,
-        # but we have to append smartpooling_factor zeros so that there is one on the
-        # border between batches in the resulting reduced sequence:
-        # BATCH_1 000 BATCH_2 000 BATCH_3 -> REDUCED_1 0 REDUCED_2 0 REDUCED_3
-        new_lens = (features_tmp[:, 1:, :] - features_tmp[:, :-1, :]).abs().sum(dim=2).squeeze(0)
-        new_lens = F.pad(new_lens, (1, 0), value=0)
-        new_lens = torch.cat([torch.cat([new_lens[i * (T + 1) + 1:(i + 1) * (T + 1) - x], torch.zeros(int(self.smartpooling_factor), device=new_lens.device)]) for i, x in enumerate(padding_per_batch)]).unsqueeze(0)
-        new_lens = new_lens / new_lens.sum(1, keepdim=True) * ((total_T / self.smartpooling_factor) + B)  # Reduce the original length T by the pooling factor
-
-        features = torch.cat([torch.cat([features[i, :T - x], torch.zeros(int(self.smartpooling_factor), C, device=new_lens.device)]) for i, x in enumerate(padding_per_batch)]).unsqueeze(0)
-        features, interp_weights = self.warp(features, new_lens)
-
-        # The idea is to remove the B - 1 longest spanning intervals,
-        # which contain several of the zeros we added earlier
-        def nonzero_interval_length(x, dim):
-            nonz = (x > 0)
-            _, low = ((nonz.cumsum(dim) == 1) & nonz).max(dim, keepdim=True)
-            rev_cumsum = nonz.long().flip(dim).cumsum(dim).flip(dim)
-            _, high = ((rev_cumsum == 1) & nonz).max(dim, keepdim=True)
-
-            return high - low + 1
-
-        # Get the indices to remove
-        lengths_nonzero = nonzero_interval_length(interp_weights, 2)
-        theor_lengths = ((T - padding_per_batch) // int(self.smartpooling_factor) + 1).view(-1)
-        theor_cumsum = theor_lengths.cumsum(0)
-        theor_lengths = (theor_lengths.float() * self.smartpooling_search_perc).long()
-        to_remove = torch.cat(
-            [torch.argmax(
-                lengths_nonzero[:, theor_cumsum[i] - theor_lengths[i]:theor_cumsum[i] + theor_lengths[i], :]).view(1)
-             + theor_cumsum[i] - theor_lengths[i] for i in range(0, B - 1)])
-
-        indices = buffered_arange(lengths_nonzero.size(1))
-        indices = indices.to(lengths_nonzero.device)
-        to_remove = torch.cat([to_remove.view(-1), indices[-1].view(1)])
-
-        # Remove the selected indices
-        mask = torch.ones_like(features, dtype=torch.bool, device=features.device).view(1, -1, C)
-        mask[0, to_remove, :] = False
-        features = features[mask].view(-1, C)
-
-        # Compute the new, padded features
-        start_idx, _ = torch.sort(to_remove)
-        start_idx = start_idx - buffered_arange(B).to(features.device)
-        start_idx = F.pad(start_idx, [1, 0])
-        sizes = start_idx[1:] - start_idx[:-1]
-        new_T = torch.max(sizes)
-        sizes = new_T - sizes
-
-        features = torch.cat([torch.cat([features[start_idx[i - 1]:start_idx[i]], torch.zeros(sizes[i - 1], C, device=features.device)]) for i in range(1, B + 1)])
-        features = features.view(B, new_T, C)
-
-        # Compute the new padding mask
-        if padding_mask is not None:
-            padding_mask = torch.zeros(B, new_T, dtype=torch.bool, device=features.device)
-            for i, x in enumerate(sizes):
-                padding_mask[i, new_T - x:] = True
-
-        return features, padding_mask
-
-    def warp(self, X, new_lens):
-        new_lens_cs = new_lens.cumsum(1)
-        # This really searches for the low boundary of each new pixel
-        pixel_contributions = new_lens_cs.view(1, -1, 1) - torch.arange(torch.round(new_lens_cs[0, -1]).item(), device=X.device).view(1, 1, -1)
-        pixel_contributions = pixel_contributions.view(X.size(0), X.size(1), pixel_contributions.size(2))
-        # Zero out the negative contributions, i.e. pixels which come before each row
-        pixel_contributions = torch.max(torch.tensor(0.0, device=X.device), pixel_contributions)
-
-        # pixel_contributions now contains the cumulated pixel lengths
-        # for all pixels in each row
-
-        pixel_contributions = pixel_contributions.unsqueeze(1)
-        interp_weights = F.conv2d(pixel_contributions, self.smartpooling_filters, padding=1)
-        interp_weights = interp_weights[:, :, :-1, 1:]  # Remove the conv padding
-        interp_weights = interp_weights.squeeze(1)
-
-        # Each column of interp_weights corresponds to a new element.
-        # Its values are the weights associated with the original data,
-        # i.e. how much each original frame contributes to that new element.
-
-        interp_weights = interp_weights.transpose(1, 2)
-        Xnew = interp_weights @ X
-        return Xnew, interp_weights
-
    def forward(self, source, padding_mask=None, mask=True, features_only=False):
        # padding_mask = None  # JCh: padding_mask probably needs to be True where the data is padded. mask=True => data invalid

@@ -672,8 +571,8 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False):
            padding_mask = padding_mask[:, ::scale]
            assert np.all(padding_mask.shape == features.shape[:-1])

-        if self.smartpooling:
-            features, padding_mask = self.smartpool(features, padding_mask=padding_mask)
+        if self.smartpool is not None:
+            features, padding_mask = self.smartpool(features, padding_mask)
        unmasked_features = features.clone()

        if self.post_extract_proj is not None:
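
For reference, below is a minimal sketch of the Smartpool interface this diff assumes. Only the constructor arguments and the `features, padding_mask = self.smartpool(features, padding_mask)` call are taken from the diff; the internals condense the removed warp() logic with a simplified length computation, and the attribute names (`factor`, `search_perc`, `filters`) are placeholders. The padding handling of the removed smartpool() method (mean-value front padding, boundary zeros between utterances, the search_perc-driven interval removal, mask rebuilding) is omitted, so this is an illustration rather than the actual module added in the fork.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Smartpool(nn.Module):
    # Sketch only: interface inferred from the diff above, internals condensed
    # from the removed smartpool()/warp() methods and simplified.
    def __init__(self, factor, search_perc):
        super().__init__()
        self.factor = factor
        self.search_perc = search_perc  # unused here; drives the padding-aware interval search in the full version
        # Registering the 2x2 differencing filter as a buffer lets .to()/.cuda()
        # move it with the model, which would make the removed to() override unnecessary.
        self.register_buffer("filters", torch.tensor([[[[-1.0, 1.0], [1.0, -1.0]]]]))

    def warp(self, X, new_lens):
        # Cumulative target lengths mark the low boundary of each pooled frame.
        new_lens_cs = new_lens.cumsum(1)
        n_new = int(torch.round(new_lens_cs[0, -1]).item())
        contrib = new_lens_cs.unsqueeze(2) - torch.arange(n_new, device=X.device, dtype=X.dtype).view(1, 1, -1)
        contrib = contrib.clamp(min=0.0)  # frames entirely before a pooled frame contribute nothing
        # The 2x2 filter differences the cumulative contributions into interpolation weights.
        w = F.conv2d(contrib.unsqueeze(1), self.filters, padding=1)[:, :, :-1, 1:]
        w = w.squeeze(1).transpose(1, 2)  # (B, new_T, T)
        return w @ X, w

    def forward(self, features, padding_mask):
        # Simplified: assumes no padded frames. The removed implementation also
        # strips padding, inserts zero-length boundaries between utterances and
        # rebuilds the padding mask for the pooled sequence.
        B, T, C = features.size()
        padded = F.pad(features, (0, 0, 1, 0))  # zero front pad (the removed code padded with the mean feature value)
        # Per-frame "length" = how strongly the features change at that frame.
        new_lens = (padded[:, 1:, :] - padded[:, :-1, :]).abs().sum(dim=2)
        # Rescale so every utterance pools to roughly T / factor frames.
        new_lens = new_lens / new_lens.sum(1, keepdim=True) * (T / self.factor)
        features, _ = self.warp(features, new_lens)
        if padding_mask is not None:
            padding_mask = features.new_zeros(features.size(0), features.size(1), dtype=torch.bool)
        return features, padding_mask

With this shape, `Smartpool(args.smartpooling_factor, args.smartpooling_search_perc)` and the call in forward line up with the diff; supporting padded batches would require reinstating the removed boundary-zero and interval-removal logic.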