22 commits
78829b6
Fixing some defaults if vocab not specified
petropusz Nov 30, 2020
2c71d2c
'Hierarchical variance segmentation' initial version, without torch f…
petropusz Nov 30, 2020
67890ad
option for merging until desired num of segments
petropusz Dec 1, 2020
5dc5500
part of torch segmentation layer
petropusz Dec 1, 2020
e035877
initial untested version of torch segmentation layer
petropusz Dec 1, 2020
400abc8
some fixes and testing to initial hierarchical variance clustering ve…
petropusz Dec 1, 2020
4a5afef
removed debug printout from parts that are to be used outside debuggi…
petropusz Dec 2, 2020
2bf242a
torch tensor device fixes in forward/backward + num of segments taken…
petropusz Dec 2, 2020
8f27475
using segmentation in wav2vec2_scribblelens as an option, but getting…
petropusz Dec 2, 2020
b8ded9c
added minimum number of segments per line so that masking at least 2 …
petropusz Dec 3, 2020
de752f3
SegmentDict fix
petropusz Dec 4, 2020
fdcc628
aaand removed printouts
petropusz Dec 4, 2020
4cc9f3b
initial basic segmentation image logging
petropusz Dec 15, 2020
17f91c8
better segmented image logging
petropusz Dec 15, 2020
b45de5d
segmentation logging fixes
petropusz Dec 16, 2020
e4d4a73
added option for square error cost instead of variance/mse to use as p…
petropusz Dec 17, 2020
47e30b1
option for logging representations (with input images also)
petropusz Dec 22, 2020
e895131
initial representation similarity plotting with some fixes
petropusz Dec 28, 2020
544d335
cosine distance option and preparation for different segment shorteni…
petropusz Jan 2, 2021
b6dc688
added options with segmentation with only averaging without shortenin…
petropusz Jan 4, 2021
0087112
fixing bugs from last 2 commits
petropusz Jan 4, 2021
91bf9a1
added rounding loss option to segmentation
petropusz Jan 8, 2021
13 changes: 11 additions & 2 deletions fairseq/criterions/wav2vec_criterion.py
@@ -14,11 +14,12 @@

@register_criterion("wav2vec")
class Wav2vecCriterion(FairseqCriterion):
def __init__(self, task, infonce=False, loss_weights=None, log_keys=None):
def __init__(self, task, infonce=False, loss_weights=None, log_keys=None, pass_metadata=False):
super().__init__(task)
self.infonce = infonce
self.loss_weights = None if loss_weights is None else eval(loss_weights)
self.log_keys = [] if log_keys is None else eval(log_keys)
self.pass_metadata = pass_metadata

@staticmethod
def add_args(parser):
@@ -30,6 +31,8 @@ def add_args(parser):
help='weights for additional loss terms (not first one)')
parser.add_argument('--log-keys', type=str, default=None,
help='output keys to log')
parser.add_argument('--pass-metadata', action='store_true',
help='if set, passes sample ids and the epoch number to the model (for model-specific logging of examples with particular ids each epoch, etc.)')
# fmt: on

def forward(self, model, sample, reduce=True, log_pred=False):
@@ -40,7 +43,13 @@ def forward(self, model, sample, reduce=True, log_pred=False):
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
if self.pass_metadata:
# epoch is now also passed in validation, but better to be careful
net_output = model(**sample["net_input"], \
id=sample["id"], \
epoch=sample["epoch"].item() if "epoch" in sample else None)
else:
net_output = model(**sample["net_input"])
logits = model.get_logits(net_output).float()
target = model.get_targets(sample, net_output)

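For context, a minimal sketch of the receiving side, assuming a model whose forward accepts the optional id/epoch keywords that the criterion now forwards. The class name, layer sizes, and logging hook here are made up for illustration; this is not the actual wav2vec2_scribblelens code.

import torch

class MetadataAwareModel(torch.nn.Module):
    """Hypothetical model that uses the ids/epoch only for side-effect logging."""

    def __init__(self, log_ids=(0, 1), dim=16):
        super().__init__()
        self.log_ids = set(log_ids)  # sample ids to visualize once per epoch
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, source, id=None, epoch=None, **kwargs):
        x = self.linear(source)
        if id is not None and epoch is not None:
            for row, sample_id in enumerate(id.tolist()):
                if sample_id in self.log_ids:
                    # e.g. dump x[row] to tensorboard, tagged with (epoch, sample_id)
                    pass
        return {"x": x}

model = MetadataAwareModel()
out = model(torch.randn(2, 16), id=torch.tensor([0, 5]), epoch=3)

The training path is unchanged; the metadata only gates logging side effects, which matches the "--pass-metadata" help text above.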
6 changes: 4 additions & 2 deletions fairseq/data/data_utils.py
@@ -469,9 +469,11 @@ def arrange(s, e, length, keep_length):

mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

min_len = min([len(m) for m in mask_idcs])
min_len = min([len(m) for m in mask_idcs])
# [!] the unpadded part of the input sequence has to be long enough for this to work correctly
# (e.g. length 1 with min_mask=2 can cause problems)
for i, mask_idc in enumerate(mask_idcs):
if len(mask_idc) > min_len:
if len(mask_idc) > min_len: # keep the same number of masked positions in every line, as a simplification
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
mask[i, mask_idc] = True

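To make the new comment concrete, here is a small standalone NumPy illustration of the subsampling step (toy indices, not taken from the repo):

import numpy as np

mask_idcs = [np.array([1, 4, 7, 9]), np.array([2, 5]), np.array([0, 3, 6])]
min_len = min(len(m) for m in mask_idcs)  # 2 in this toy batch
mask = np.zeros((len(mask_idcs), 10), dtype=bool)
for i, mask_idc in enumerate(mask_idcs):
    if len(mask_idc) > min_len:
        # subsample so every line masks exactly min_len positions
        mask_idc = np.random.choice(mask_idc, min_len, replace=False)
    mask[i, mask_idc] = True
# each row of `mask` now has exactly min_len True entries; a line whose unpadded
# length allows fewer than min_mask candidates would break this invariant,
# which is what the warning above is about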
15 changes: 11 additions & 4 deletions fairseq/data/handwriting/alphabet.py
@@ -62,9 +62,10 @@ class Alphabet:
"""
def __init__(self, filename_=None, input_dict=None,
translation_dict={'_': ' '},
unk=("@",), blank=("*",), space=(' ', '_')):
unk=("@",), blank=("*",), space=(' ', '_'),
ensure_in_dict_on_no_vocab=None): # option to make sure given chars are in the dict when no vocab is provided

if filename_:
if filename_: # both None and '' are falsy
self.chars = bidict(self.readDictionary(filename_))
print('Alphabet constructed from', filename_,
'size=', len(self.chars))
Expand All @@ -73,12 +74,18 @@ def __init__(self, filename_=None, input_dict=None,
print('Alphabet constructed from dictionary, '
'size=', len(self.chars))
else:
self.chars = bidict({
base_special_dict = {
k: i
for i, chs in enumerate([blank, space, unk])
for k in chs
})
}
if ensure_in_dict_on_no_vocab:
for c in ensure_in_dict_on_no_vocab:
if c not in base_special_dict:
base_special_dict[c] = len(base_special_dict)
self.chars = bidict(base_special_dict)
print('Alphabet constructed empty')

for c in unk:
if c not in self.chars:
print('Warning: UNK token', c, 'not in vocab')
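A standalone illustration of the new ensure_in_dict_on_no_vocab path, using the same comprehension as Alphabet.__init__ when no vocab file or dict is given. The <s>/<pad>/</s> symbols are assumed here for illustration (handwriting_dictionary.py below passes its own bos/pad/eos/unk, whose defaults are not shown in this hunk):

blank, space, unk = ("*",), (" ", "_"), ("@",)
ensure_in_dict_on_no_vocab = ("<s>", "<pad>", "</s>", "@")  # assumed specials

# same comprehension as in Alphabet.__init__ on the no-vocab branch
base_special_dict = {
    k: i
    for i, chs in enumerate([blank, space, unk])
    for k in chs
}
for c in ensure_in_dict_on_no_vocab:
    if c not in base_special_dict:
        base_special_dict[c] = len(base_special_dict)

print(base_special_dict)
# {'*': 0, ' ': 1, '_': 1, '@': 2, '<s>': 4, '<pad>': 5, '</s>': 6}
# '@' is already the unk char, so it is not duplicated; ensured chars get
# len(dict) as their index, so index 3 is skipped because ' ' and '_' share 1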
4 changes: 2 additions & 2 deletions fairseq/data/handwriting/handwriting_dictionary.py
@@ -16,11 +16,11 @@ def __init__(
): #extra_special_symbols=None,):

# [!] bos, pad, eos etc. need to be in dict file
super().__init__(alphabet_file, unk=(unk,))
super().__init__(alphabet_file, unk=(unk,), ensure_in_dict_on_no_vocab=(bos, pad, eos, unk))
#self._alphabet = Alphabet(alphabet_file, unk=(unk,))
for c, descr in zip((bos, pad, eos, unk), ("bos", "pad", "eos", "unk")):
if not self.existDict(c):
print('WARNING:', descr, 'token', c, 'not in vocab')
print('ERROR:', descr, 'token', c, 'not in the chosen vocab (vocab was loaded from file, not constructed)')
self.bos_char, self.unk_char, self.pad_char, self.eos_char = bos, unk, pad, eos
#self.symbols = []
#self.count = []