Skip to content

Commit 2acbd29

Browse files
committed
Moved smartpooling to separate layer
1 parent 482aa3e commit 2acbd29

File tree

3 files changed

+123
-105
lines changed

3 files changed

+123
-105
lines changed

fairseq/models/wav2vec/wav2vec2_scribblelens.py

Lines changed: 4 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
LayerNorm,
2525
MultiheadAttention,
2626
SamePad,
27+
Smartpool,
2728
TransposeLast,
2829
)
2930
from fairseq.modules.transformer_sentence_encoder import init_bert_params
@@ -330,10 +331,7 @@ def __init__(self, args):
330331
conv_bias=args.conv_bias,
331332
)
332333

333-
self.smartpooling = args.smartpooling
334-
self.smartpooling_search_perc = args.smartpooling_search_perc
335-
self.smartpooling_factor = args.smartpooling_factor
336-
self.smartpooling_filters = torch.tensor([[[[-1,1],[1,-1]]]]).float()
334+
self.smartpool = Smartpool(args.smartpooling_factor, args.smartpooling_search_perc) if args.smartpooling else None
337335
self.post_extract_proj = (
338336
nn.Linear(self.embed, args.encoder_embed_dim)
339337
if self.embed != args.encoder_embed_dim and not args.quantize_input
@@ -425,11 +423,6 @@ def upgrade_state_dict_named(self, state_dict, name):
425423
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
426424
return state_dict
427425

428-
def to(self, *args, **kwargs):
429-
self = super().to(*args, **kwargs)
430-
self.smartpooling_filters = self.smartpooling_filters.to(*args, **kwargs)
431-
return self
432-
433426
@classmethod
434427
def build_model(cls, args, task=None):
435428
"""Build a new model instance."""
@@ -552,100 +545,6 @@ def compute_preds(self, x, y, negatives):
552545

553546
return logits
554547

555-
def smartpool(self, features, padding_mask=None):
556-
B,T,C = features.size()
557-
558-
padding_per_batch = (padding_mask > 0).sum(1)
559-
total_T = padding_mask.numel() - padding_per_batch.sum()
560-
features_together = torch.cat([features[i,:T-x] for i,x in enumerate(padding_per_batch)]).unsqueeze(0)
561-
562-
features_tmp = F.pad(features, (0,0,1,0), value=features_together.mean().item())
563-
features_tmp = features_tmp.view(1, B * (T+1), C)
564-
565-
# We have to remove 1 front padding and X_i back paddings from each batch. X_i can be arbitrary
566-
# but we have to append smartpooling_factors zeros so that there is one on the
567-
# border between batches in resulting reduced sequence
568-
# BATCH_1 000 BATCH_2 000 BATCH_3 -> REDUCED_1 0 REDUCED_2 0 REDUCED_3
569-
new_lens = (features_tmp[:,1:,:] - features_tmp[:,:-1,:]).abs().sum(dim=2).squeeze(0)
570-
new_lens = F.pad(new_lens, (1,0), value=0)
571-
new_lens = torch.cat([torch.cat([new_lens[i*(T+1)+1:(i+1)*(T+1)-x], torch.zeros(int(self.smartpooling_factor), device=new_lens.device)]) for i,x in enumerate(padding_per_batch)]).unsqueeze(0)
572-
new_lens = new_lens / new_lens.sum(1, keepdim=True) * ((total_T / self.smartpooling_factor) + B) # Reducing the original length T by some factor
573-
574-
features = torch.cat([torch.cat([features[i,:T-x], torch.zeros(int(self.smartpooling_factor), C, device=new_lens.device)]) for i,x in enumerate(padding_per_batch)]).unsqueeze(0)
575-
features, interp_weights = self.warp(features, new_lens)
576-
577-
# The idea is to remove B-1 the longest spanning intervals
578-
# which contain several zeros we added earlier
579-
def nonzero_interval_length(x, dim):
580-
nonz = (x > 0)
581-
_, low = ((nonz.cumsum(dim) == 1) & nonz).max(dim, keepdim=True)
582-
rev_cumsum = nonz.long().flip(dim).cumsum(dim).flip(dim)
583-
_, high = ((rev_cumsum == 1) & nonz).max(dim, keepdim=True)
584-
585-
return high - low + 1
586-
587-
# Get the indices to remove
588-
lengths_nonzero = nonzero_interval_length(interp_weights, 2)
589-
theor_lengths = ((T - padding_per_batch) // int(self.smartpooling_factor) + 1).view(-1)
590-
theor_cumsum = theor_lengths.cumsum(0)
591-
theor_lengths = (theor_lengths.float() * self.smartpooling_search_perc).long()
592-
to_remove = torch.cat(
593-
[torch.argmax(
594-
lengths_nonzero[:, theor_cumsum[i] - theor_lengths[i] : theor_cumsum[i] + theor_lengths[i], :]).view(1)
595-
+ theor_cumsum[i] - theor_lengths[i] for i in range(0,B-1)])
596-
597-
indices = buffered_arange(lengths_nonzero.size(1))
598-
indices = indices.to(lengths_nonzero.device)
599-
to_remove = torch.cat([to_remove.view(-1), indices[-1].view(1)])
600-
601-
# Remove indices
602-
mask = torch.ones_like(features, dtype=torch.bool, device=features.device).view(1, -1, C)
603-
mask[0, to_remove, :] = False
604-
features = features[mask].view(-1,C)
605-
606-
# Compute new features with padding
607-
start_idx, _ = torch.sort(to_remove)
608-
start_idx = start_idx - buffered_arange(B).to(features.device)
609-
start_idx = F.pad(start_idx, [1,0])
610-
sizes = start_idx[1:] - start_idx[:-1]
611-
new_T = torch.max(sizes)
612-
sizes = new_T - sizes
613-
614-
features = torch.cat([torch.cat([features[start_idx[i-1]:start_idx[i]], torch.zeros(sizes[i-1], C, device=features.device)]) for i in range(1,B+1)])
615-
features = features.view(B, new_T, C)
616-
617-
# Compute new mask padding mask
618-
if padding_mask is not None:
619-
padding_mask = torch.zeros(B, new_T, dtype=torch.bool, device=features.device)
620-
for i,x in enumerate(sizes):
621-
padding_mask[i, new_T-x:] = True
622-
623-
return features, padding_mask
624-
625-
def warp(self, X, new_lens):
626-
new_lens_cs = new_lens.cumsum(1)
627-
# This really searches for the low boundary of each new pixel
628-
pixel_contributions = new_lens_cs.view(1, -1, 1) - torch.arange(torch.round(new_lens_cs[0, -1]).item(), device=X.device).view(1, 1, -1)
629-
pixel_contributions = pixel_contributions.view(X.size(0), X.size(1), pixel_contributions.size(2))
630-
# Zero out the negative contributions, i.e. pixels which come before each row
631-
pixel_contributions = torch.max(torch.tensor(0.0, device=X.device), pixel_contributions)
632-
633-
# # This contains the cumulated pixel lengths for all pixels in each
634-
# pixel_contributions
635-
636-
pixel_contributions = pixel_contributions.unsqueeze(1)
637-
interp_weights = F.conv2d(pixel_contributions, self.smartpooling_filters, padding=1)
638-
interp_weights = interp_weights[:,:,:-1,1:] # Removing padding
639-
interp_weights = interp_weights.squeeze(1)
640-
641-
# # Each column corresponds to a new element. Its values are the
642-
# # weights associated with the original data.
643-
# interp_weights
644-
645-
interp_weights = interp_weights.transpose(1, 2)
646-
Xnew = interp_weights @ X
647-
return Xnew, interp_weights
648-
649548
def forward(self, source, padding_mask=None, mask=True, features_only=False):
650549
# padding_mask = None # JCh: padding_mask prob need to be True where the data is padded. mask=True => data invalid
651550

@@ -672,8 +571,8 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False):
672571
padding_mask = padding_mask[:, ::scale]
673572
assert np.all(padding_mask.shape == features.shape[:-1])
674573

675-
if self.smartpooling:
676-
features, padding_mask = self.smartpool(features, padding_mask=padding_mask)
574+
if self.smartpool is not None:
575+
features, padding_mask = self.smartpool(features, padding_mask)
677576
unmasked_features = features.clone()
678577

679578
if self.post_extract_proj is not None:

fairseq/modules/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from .same_pad import SamePad
3030
from .scalar_bias import ScalarBias
3131
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
32+
from .smartpool import Smartpool
3233
from .transformer_sentence_encoder_layer import TransformerSentenceEncoderLayer
3334
from .transformer_sentence_encoder import TransformerSentenceEncoder
3435
from .transpose_last import TransposeLast
@@ -66,6 +67,7 @@
6667
"SamePad",
6768
"ScalarBias",
6869
"SinusoidalPositionalEmbedding",
70+
"Smartpool",
6971
"TransformerSentenceEncoderLayer",
7072
"TransformerSentenceEncoder",
7173
"TransformerDecoderLayer",

fairseq/modules/smartpool.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
class Smartpool(nn.Module):
    """Content-adaptive temporal pooling ("smart pooling").

    Each input frame is given a length proportional to how much it differs
    from the previous frame (L1 distance over channels).  The sequence is then
    resampled onto a uniform grid, so stretches with little change are merged
    more aggressively than stretches with a lot of change.  All sequences of a
    batch are processed as one long concatenated sequence and split again
    afterwards.
    """

    def __init__(
        self,
        factor,
        search_perc
    ):
        """Smart pooling algorithm

        Args:
            factor: factor by which the sequence's length will be reduced
            search_perc: fraction of the expected pooled length of a sequence
                that is searched, on both sides of the expected position, for
                the border between consecutive sequences
        """
        super().__init__()

        self.search_perc = search_perc
        self.factor = factor
        # 2x2 mixed finite-difference kernel used by ``warp``.  Registered as
        # a non-persistent buffer so that it follows ``.to()`` / ``.cuda()``
        # without being written to checkpoints.
        self.register_buffer(
            "filters",
            torch.tensor([[[[-1, 1], [1, -1]]]], dtype=torch.float),
            persistent=False,
        )

    def warp(self, X, new_lens):
        """Resample ``X`` so that frame ``i`` occupies ``new_lens[0, i]`` units.

        Args:
            X: tensor viewed here as ``(X.size(0), X.size(1), C)``; ``forward``
                calls this with a single concatenated sequence ``(1, L, C)``.
            new_lens: ``(1, L)`` non-negative target length of every frame.

        Returns:
            ``(Xnew, interp_weights)`` where ``interp_weights`` is
            ``(1, N, L)`` with ``N = round(new_lens.sum())`` and
            ``Xnew = interp_weights @ X``.
        """
        new_lens_cs = new_lens.cumsum(1)
        # Left edge of every output pixel: 0, 1, ..., N - 1.
        out_edges = torch.arange(
            torch.round(new_lens_cs[0, -1]).item(), device=X.device
        )
        # How far the end of each input frame reaches past each output edge.
        pixel_contributions = new_lens_cs.view(1, -1, 1) - out_edges.view(1, 1, -1)
        pixel_contributions = pixel_contributions.view(
            X.size(0), X.size(1), pixel_contributions.size(2)
        )
        # Zero out the negative contributions, i.e. output pixels that start
        # after the input frame has already ended.
        pixel_contributions = pixel_contributions.clamp(min=0)

        # Differencing along both axes turns the clamped cumulative reach into
        # the overlap between every input frame and every output pixel.
        pixel_contributions = pixel_contributions.unsqueeze(1)
        interp_weights = F.conv2d(pixel_contributions, self.filters, padding=1)
        interp_weights = interp_weights[:, :, :-1, 1:]  # Removing padding
        interp_weights = interp_weights.squeeze(1)

        # After the transpose every row corresponds to one output pixel and
        # holds the weights of the original frames.
        interp_weights = interp_weights.transpose(1, 2)
        Xnew = interp_weights @ X
        return Xnew, interp_weights

    def nonzero_interval_length(self, x, dim):
        """Return, along ``dim``, the span between the first and last entry > 0.

        The reduced dimension is kept (size 1).  The result is
        ``last_index - first_index + 1``.
        """
        nonz = (x > 0)
        # Exactly one position satisfies "running count == 1 and non-zero":
        # the first non-zero entry.  The reversed running count finds the last.
        first_hit = (nonz.long().cumsum(dim) == 1) & nonz
        low = first_hit.long().argmax(dim, keepdim=True)
        rev_cumsum = nonz.long().flip(dim).cumsum(dim).flip(dim)
        last_hit = (rev_cumsum == 1) & nonz
        high = last_hit.long().argmax(dim, keepdim=True)

        return high - low + 1

    def forward(self, features, padding_mask):
        """Pool ``features`` along time by roughly ``self.factor``.

        Args:
            features: ``(B, T, C)`` tensor.
            padding_mask: ``(B, T)`` tensor that is non-zero/True on padded
                frames (padding is assumed to be at the end of each row), or
                ``None`` when no frame is padded.

        Returns:
            ``(features, padding_mask)``: pooled features ``(B, new_T, C)``,
            zero padded at the end of shorter rows, and the matching boolean
            mask ``(B, new_T)``.  If ``padding_mask`` was ``None`` the returned
            mask is ``None`` as well; note that rows of a multi-sequence batch
            may still end in zero padding in that case.
        """
        B, T, C = features.size()
        # Number of zero rows put after every sequence before warping.
        sep = 3 * int(self.factor)

        if padding_mask is None:
            pad_counts = [0] * B
        else:
            pad_counts = (padding_mask > 0).sum(1).tolist()
        valid_lens = [T - p for p in pad_counts]
        total_T = sum(valid_lens)

        valid = [features[i, :n] for i, n in enumerate(valid_lens)]
        # One frame holding the mean of all valid values is prepended to every
        # sequence so that the first real frame also gets a difference.
        fill = torch.cat(valid).mean().item()
        features_tmp = F.pad(features, (0, 0, 1, 0), value=fill)
        features_tmp = features_tmp.view(1, B * (T + 1), C)

        # L1 distance of every frame to its predecessor in the flattened batch.
        new_lens = (features_tmp[:, 1:, :] - features_tmp[:, :-1, :]).abs().sum(dim=2).squeeze(0)
        new_lens = F.pad(new_lens, (1, 0), value=0)

        # We have to remove 1 front padding and X_i back paddings from each
        # sequence, and append `sep` (= 3 * factor) zeros so that there is a
        # pooled frame on the border between sequences:
        # SEQ_1 000 SEQ_2 000 SEQ_3 -> REDUCED_1 0 REDUCED_2 0 REDUCED_3
        zero_lens = new_lens.new_zeros(sep)
        len_parts = []
        for i, n in enumerate(valid_lens):
            start = i * (T + 1) + 1  # skip the prepended mean frame
            len_parts.append(new_lens[start:start + n])
            len_parts.append(zero_lens)
        new_lens = torch.cat(len_parts).unsqueeze(0)
        # Reduce the original length by `factor`, plus one border frame per
        # sequence (B frames are removed again below).
        new_lens = new_lens / new_lens.sum(1, keepdim=True) * (total_T / self.factor + B)

        zero_feats = features.new_zeros(sep, C)
        feat_parts = []
        for seq in valid:
            feat_parts.append(seq)
            feat_parts.append(zero_feats)
        pooled, interp_weights = self.warp(torch.cat(feat_parts).unsqueeze(0), new_lens)

        # The idea is to remove the B-1 pooled frames whose weights span the
        # zeros added above (they have the longest non-zero interval), plus the
        # very last pooled frame.
        lengths_nonzero = self.nonzero_interval_length(interp_weights, 2)
        num_pooled = lengths_nonzero.size(1)

        theor_lengths = torch.tensor([n // int(self.factor) + 1 for n in valid_lens])
        centers = theor_lengths.cumsum(0).tolist()
        radii = (theor_lengths.float() * self.search_perc).long().tolist()

        to_remove = []
        for center, radius in zip(centers[:-1], radii[:-1]):
            lo = center - radius
            # Always look at least at the expected border itself; a radius of
            # zero would otherwise give an empty window and argmax would fail.
            hi = max(center + radius, lo + 1)
            window = lengths_nonzero[:, lo:hi, :]
            to_remove.append(torch.argmax(window).view(1) + lo)
        to_remove.append(
            torch.tensor([num_pooled - 1], dtype=torch.long, device=lengths_nonzero.device)
        )
        # The list is never empty, so this also works for a batch of one
        # sequence (torch.cat rejects an empty list).
        to_remove = torch.cat(to_remove)

        # Remove indices
        keep = torch.ones(num_pooled, dtype=torch.bool, device=pooled.device)
        keep[to_remove] = False
        kept = pooled[0, keep]

        # Compute new features with padding: after the removal, sequence i
        # occupies kept[starts[i]:ends[i]].
        ends = (torch.sort(to_remove)[0] - torch.arange(B, device=to_remove.device)).tolist()
        starts = [0] + ends[:-1]
        seq_lens = [e - s for s, e in zip(starts, ends)]
        new_T = max(seq_lens)

        rows = []
        for s, e in zip(starts, ends):
            rows.append(kept[s:e])
            rows.append(kept.new_zeros(new_T - (e - s), C))
        features = torch.cat(rows).view(B, new_T, C)

        # Compute new padding mask
        if padding_mask is not None:
            positions = torch.arange(new_T, device=features.device).unsqueeze(0)
            lens = torch.tensor(seq_lens, device=features.device).unsqueeze(1)
            padding_mask = positions >= lens

        return features, padding_mask
116+
117+

0 commit comments

Comments
 (0)