27 changes: 25 additions & 2 deletions fairseq/models/wav2vec/wav2vec2_scribblelens.py
@@ -24,6 +24,7 @@
LayerNorm,
MultiheadAttention,
SamePad,
Smartpool,
TransposeLast,
)
from fairseq.modules.transformer_sentence_encoder import init_bert_params
@@ -298,6 +299,24 @@ def add_args(parser):
"--conv-bias", action="store_true", help="include bias in conv encoder"
)

parser.add_argument(
"--smartpooling", action="store_true", help="whether to perform smartpooling"
)

parser.add_argument(
"--smartpooling-factor",
type=float,
default=3,
help="factor by which the sequence's length will be reduced in smartpooling"
)

parser.add_argument(
"--smartpooling-search-perc",
type=float,
default=0.3,
help="percentage of length of sequence after smartpooling to search for border. Ideally the border is located somewhere in +-search_perc"
)

def __init__(self, args):
super().__init__()
self.args = args
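Taken together, the new options would be enabled on the command line roughly as follows (a fragment only; the remaining fairseq-train arguments are unchanged and omitted here):

fairseq-train ... --smartpooling --smartpooling-factor 3 --smartpooling-search-perc 0.3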
@@ -312,6 +331,7 @@ def __init__(self, args):
conv_bias=args.conv_bias,
)

self.smartpool = Smartpool(args.smartpooling_factor, args.smartpooling_search_perc) if args.smartpooling else None
self.post_extract_proj = (
nn.Linear(self.embed, args.encoder_embed_dim)
if self.embed != args.encoder_embed_dim and not args.quantize_input
@@ -541,8 +561,7 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False):

features = features.transpose(1, 2)
features = self.layer_norm(features)
unmasked_features = features.clone()


if padding_mask is not None:
assert padding_mask.size(1) == 1
padding_mask = padding_mask.squeeze(1)
@@ -552,6 +571,10 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False):
padding_mask = padding_mask[:, ::scale]
assert np.all(padding_mask.shape == features.shape[:-1])

if self.smartpool is not None:
features, padding_mask = self.smartpool(features, padding_mask)
unmasked_features = features.clone()

if self.post_extract_proj is not None:
features = self.post_extract_proj(features)

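For reference, a minimal sketch of how the new flags reach the module, assuming an argparse-style namespace like the one fairseq passes to the model constructor (the namespace below is illustrative, not a complete model config):

from argparse import Namespace
from fairseq.modules import Smartpool

args = Namespace(smartpooling=True, smartpooling_factor=3.0, smartpooling_search_perc=0.3)

# Mirrors the constructor logic in the diff above: the module is only
# built when --smartpooling is given.
smartpool = (
    Smartpool(args.smartpooling_factor, args.smartpooling_search_perc)
    if args.smartpooling
    else None
)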
2 changes: 2 additions & 0 deletions fairseq/modules/__init__.py
@@ -29,6 +29,7 @@
from .same_pad import SamePad
from .scalar_bias import ScalarBias
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
from .smartpool import Smartpool
from .transformer_sentence_encoder_layer import TransformerSentenceEncoderLayer
from .transformer_sentence_encoder import TransformerSentenceEncoder
from .transpose_last import TransposeLast
@@ -66,6 +67,7 @@
"SamePad",
"ScalarBias",
"SinusoidalPositionalEmbedding",
"Smartpool",
"TransformerSentenceEncoderLayer",
"TransformerSentenceEncoder",
"TransformerDecoderLayer",
117 changes: 117 additions & 0 deletions fairseq/modules/smartpool.py
@@ -0,0 +1,117 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class Smartpool(nn.Module):
def __init__(
self,
factor,
search_perc
):
"""Smart pooling algorithm

Args:
factor: factor by which the sequence's length will be reduced
search_perc: fraction of the post-pooling sequence length to search for each batch border; ideally the border lies within +-search_perc of its expected position
"""
super().__init__()

self.search_perc = search_perc
self.factor = factor
self.register_buffer("filters", torch.FloatTensor([[[[-1,1],[1,-1]]]]), persistent=False)

def warp(self, X, new_lens):
new_lens_cs = new_lens.cumsum(1)
# This effectively locates the lower boundary of each output frame
pixel_contributions = new_lens_cs.view(1, -1, 1) - torch.arange(torch.round(new_lens_cs[0, -1]).item(), device=X.device).view(1, 1, -1)
pixel_contributions = pixel_contributions.view(X.size(0), X.size(1), pixel_contributions.size(2))
# Zero out the negative contributions, i.e. pixels which come before each row
pixel_contributions = torch.max(torch.tensor(0.0, device=X.device), pixel_contributions)

# At this point pixel_contributions holds, for every (input frame, output
# frame) pair, how far the cumulative input length extends past the output
# frame's start (clipped at 0); the 2x2 difference filter below turns these
# into per-pair overlap weights.

pixel_contributions = pixel_contributions.unsqueeze(1)
interp_weights = F.conv2d(pixel_contributions, self.filters, padding=1)
interp_weights = interp_weights[:,:,:-1,1:] # Removing padding
interp_weights = interp_weights.squeeze(1)

# Each column of interp_weights corresponds to one new (output) frame; its
# values are the interpolation weights over the original frames.

interp_weights = interp_weights.transpose(1, 2)
Xnew = interp_weights @ X
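# Worked example (illustrative): for a single 3-frame sequence with
# new_lens = [1.5, 1.5, 1.0], new_lens_cs = [1.5, 3.0, 4.0], so the four
# unit-length output frames overlap the inputs as
#   out0 = 1.0*in0
#   out1 = 0.5*in0 + 0.5*in1
#   out2 = 1.0*in1
#   out3 = 1.0*in2
# i.e. interp_weights[0] == [[1, 0, 0], [.5, .5, 0], [0, 1, 0], [0, 0, 1]].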
return Xnew, interp_weights

def nonzero_interval_length(self, x, dim):
# Length of the contiguous index span between the first and last nonzero
# entries along `dim`.
nonz = (x > 0)
_, low = ((nonz.cumsum(dim) == 1) & nonz).max(dim, keepdim=True)
rev_cumsum = nonz.long().flip(dim).cumsum(dim).flip(dim)
_, high = ((rev_cumsum == 1) & nonz).max(dim, keepdim=True)

return high - low + 1

def forward(self, features, padding_mask):
B,T,C = features.size()

padding_per_batch = (padding_mask > 0).sum(1)
total_T = padding_mask.numel() - padding_per_batch.sum()
features_together = torch.cat([features[i,:T-x] for i,x in enumerate(padding_per_batch)]).unsqueeze(0)

features_tmp = F.pad(features, (0,0,1,0), value=features_together.mean().item())
features_tmp = features_tmp.view(1, B * (T+1), C)

# We have to drop the single front pad and the X_i back pads of each batch
# element. X_i varies per element, but we append 3*factor zero-length frames
# so that at least one frame of the reduced sequence falls on the border
# between batch elements:
# BATCH_1 000 BATCH_2 000 BATCH_3 -> REDUCED_1 0 REDUCED_2 0 REDUCED_3
new_lens = (features_tmp[:,1:,:] - features_tmp[:,:-1,:]).abs().sum(dim=2).squeeze(0)
new_lens = F.pad(new_lens, (1,0), value=0)
new_lens = torch.cat([torch.cat([new_lens[i*(T+1)+1:(i+1)*(T+1)-x], torch.zeros(3*int(self.factor), device=new_lens.device)]) for i,x in enumerate(padding_per_batch)]).unsqueeze(0)
new_lens = new_lens / new_lens.sum(1, keepdim=True) * ((total_T / self.factor) + B) # Reducing the original length T by some factor
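# After this normalization the warped output has ~(total_T / factor) + B
# unit-length frames in total: about T_i / factor frames per batch element
# plus roughly one separator frame per element coming from the appended
# zero-length borders.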

features = torch.cat([torch.cat([features[i,:T-x], torch.zeros(3*int(self.factor), C, device=new_lens.device)]) for i,x in enumerate(padding_per_batch)]).unsqueeze(0)
features, interp_weights = self.warp(features, new_lens)

# The idea is to remove the B-1 longest-spanning intervals, which contain
# the zero-length border frames we added earlier

# Get the indices to remove
lengths_nonzero = self.nonzero_interval_length(interp_weights, 2)
theor_lengths = ((T - padding_per_batch) // int(self.factor) + 1).view(-1)
theor_cumsum = theor_lengths.cumsum(0)
theor_lengths = (theor_lengths.float() * self.search_perc).long()
to_remove = torch.cat(
[torch.argmax(
lengths_nonzero[:, theor_cumsum[i] - theor_lengths[i] : theor_cumsum[i] + theor_lengths[i], :]).view(1)
+ theor_cumsum[i] - theor_lengths[i] for i in range(0,B-1)])

indices = torch.arange(lengths_nonzero.size(1), device=lengths_nonzero.device)
to_remove = torch.cat([to_remove.view(-1), indices[-1].view(1)])

# Remove indices
mask = torch.ones_like(features, dtype=torch.bool, device=features.device).view(1, -1, C)
mask[0, to_remove, :] = False
features = features[mask].view(-1,C)

# Compute new features with padding
start_idx, _ = torch.sort(to_remove)
start_idx = start_idx - torch.arange(B, device=features.device)
start_idx = F.pad(start_idx, [1,0])
sizes = start_idx[1:] - start_idx[:-1]
new_T = torch.max(sizes)
sizes = new_T - sizes

features = torch.cat([torch.cat([features[start_idx[i-1]:start_idx[i]], torch.zeros(sizes[i-1], C, device=features.device)]) for i in range(1,B+1)])
features = features.view(B, new_T, C)

# Recompute the padding mask for the pooled sequence
if padding_mask is not None:
padding_mask = torch.zeros(B, new_T, dtype=torch.bool, device=features.device)
for i,x in enumerate(sizes):
padding_mask[i, new_T-x:] = True

return features, padding_mask


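To make the expected shapes concrete, here is a minimal, illustrative sketch that drives the module directly on dummy data (the sizes are arbitrary; True in padding_mask marks padded frames, matching the forward pass above):

import torch
from fairseq.modules import Smartpool

B, T, C = 2, 12, 8
pool = Smartpool(factor=3.0, search_perc=0.3)

features = torch.randn(B, T, C)
padding_mask = torch.zeros(B, T, dtype=torch.bool)
padding_mask[1, -4:] = True  # second sequence is 4 frames shorter

pooled, pooled_mask = pool(features, padding_mask)
# Each sequence shrinks to roughly T / factor frames; pooled_mask marks the
# right-padding used to re-batch the shortened sequences.
print(pooled.shape, pooled_mask.shape)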