diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index 6ac7557dcc..dda994a63b 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -14,11 +14,12 @@ @register_criterion("wav2vec") class Wav2vecCriterion(FairseqCriterion): - def __init__(self, task, infonce=False, loss_weights=None, log_keys=None): + def __init__(self, task, infonce=False, loss_weights=None, log_keys=None, pass_metadata=False): super().__init__(task) self.infonce = infonce self.loss_weights = None if loss_weights is None else eval(loss_weights) self.log_keys = [] if log_keys is None else eval(log_keys) + self.pass_metadata = pass_metadata @staticmethod def add_args(parser): @@ -30,6 +31,8 @@ def add_args(parser): help='weights for additional loss terms (not first one)') parser.add_argument('--log-keys', type=str, default=None, help='output keys to log') + parser.add_argument('--pass-metadata', action='store_true', + help='if set, passes sample ids and epoch nr to the model (for model-specific logging of some specific-id examples per epoch etc.)') # fmt: on def forward(self, model, sample, reduce=True, log_pred=False): @@ -40,7 +43,13 @@ def forward(self, model, sample, reduce=True, log_pred=False): 2) the sample size, which is used as the denominator for the gradient 3) logging outputs to display while training """ - net_output = model(**sample["net_input"]) + if self.pass_metadata: + # epoch is now also be passed in validation, but better be careful + net_output = model(**sample["net_input"], \ + id=sample["id"], \ + epoch=sample["epoch"].item() if "epoch" in sample else None) + else: + net_output = model(**sample["net_input"]) logits = model.get_logits(net_output).float() target = model.get_targets(sample, net_output) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 81f457365a..121155f373 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -469,9 +469,11 @@ def arrange(s, e, length, keep_length): mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) - min_len = min([len(m) for m in mask_idcs]) + min_len = min([len(m) for m in mask_idcs]) + # [!] input sequence outside padding has to have appropriate length for this to work correctly + # (e.g. length 1 with min_mask=2 can cause problems) for i, mask_idc in enumerate(mask_idcs): - if len(mask_idc) > min_len: + if len(mask_idc) > min_len: # they want same number of masked stuff per line as a simplification mask_idc = np.random.choice(mask_idc, min_len, replace=False) mask[i, mask_idc] = True diff --git a/fairseq/data/handwriting/alphabet.py b/fairseq/data/handwriting/alphabet.py index 460276cccb..befc75f0bb 100644 --- a/fairseq/data/handwriting/alphabet.py +++ b/fairseq/data/handwriting/alphabet.py @@ -62,9 +62,10 @@ class Alphabet: """ def __init__(self, filename_=None, input_dict=None, translation_dict={'_': ' '}, - unk=("@",), blank=("*",), space=(' ', '_')): + unk=("@",), blank=("*",), space=(' ', '_'), + ensure_in_dict_on_no_vocab=None): # option for ensuring chars in case the vocab is not given - if filename_: + if filename_: # both None and '' will be 'False' self.chars = bidict(self.readDictionary(filename_)) print('Alphabet constructed from', filename_, 'size=', len(self.chars)) @@ -73,12 +74,18 @@ def __init__(self, filename_=None, input_dict=None, print('Alphabet constructed from dictionnary, ' 'size=', len(self.chars)) else: - self.chars = bidict({ + base_special_dict = { k: i for i, chs in enumerate([blank, space, unk]) for k in chs - }) + } + if ensure_in_dict_on_no_vocab: + for c in ensure_in_dict_on_no_vocab: + if c not in base_special_dict: + base_special_dict[c] = len(base_special_dict) + self.chars = bidict(base_special_dict) print('Alphabet constructed empty') + for c in unk: if c not in self.chars: print('Warning: UNK token', c, 'not in vocab') diff --git a/fairseq/data/handwriting/handwriting_dictionary.py b/fairseq/data/handwriting/handwriting_dictionary.py index af23179434..21ace2f5a9 100644 --- a/fairseq/data/handwriting/handwriting_dictionary.py +++ b/fairseq/data/handwriting/handwriting_dictionary.py @@ -16,11 +16,11 @@ def __init__( ): #extra_special_symbols=None,): # [!] bos, pad, eos etc. need to be in dict file - super().__init__(alphabet_file, unk=(unk,)) + super().__init__(alphabet_file, unk=(unk,), ensure_in_dict_on_no_vocab=(bos, pad, eos, unk)) #self._alphabet = Alphabet(alphabet_file, unk=(unk,)) for c, descr in zip((bos, pad, eos, unk), ("bos", "pad", "eos", "unk")): if not self.existDict(c): - print('WARNING:', descr, 'token', c, 'not in vocab') + print('ERROR:', descr, 'token', c, 'not in vocab and vocab chosen, not constructed') self.bos_char, self.unk_char, self.pad_char, self.eos_char = bos, unk, pad, eos #self.symbols = [] #self.count = [] diff --git a/fairseq/models/wav2vec/wav2vec2_scribblelens.py b/fairseq/models/wav2vec/wav2vec2_scribblelens.py index 88c3798462..885610cc2c 100644 --- a/fairseq/models/wav2vec/wav2vec2_scribblelens.py +++ b/fairseq/models/wav2vec/wav2vec2_scribblelens.py @@ -25,10 +25,14 @@ MultiheadAttention, SamePad, TransposeLast, + HierarchicalSegmentationLayer, ) from fairseq.modules.transformer_sentence_encoder import init_bert_params from fairseq.utils import buffered_arange +import random +from PIL import Image, ImageDraw + @register_model("wav2vec2_scribblelens") class Wav2Vec2ModelSL(BaseFairseqModel): @staticmethod @@ -298,6 +302,39 @@ def add_args(parser): "--conv-bias", action="store_true", help="include bias in conv encoder" ) + parser.add_argument( + "--segm", type=str, help="use segmentation on representations; 'hier' (without ') for hierarchical segm; " \ + + "also contains options, e.g. for var format is hier::::, where:\n" \ + + " i) is se (squared error), var (variance, se div by length), cos (cosine similarity mapped linearly to distance metric and scaled with segment length) \n" \ + + " ii) is additional rounding loss to use (se, var, lin, cos, or none) - measuring distance of average given for segment from original representations; " \ + + "need to add weight for this loss in loss-weights param if not none\n" \ + + " iii) is one of: shorten (averages in segments and replaces each with length 1), orig_len (replace with mean in segments, but keep length), " \ + + "orig_len+guess_orig (as in orig_len, but use original not-averaged representations as masked ones to guess correct one from)\n" \ + + " iv) is/are float/floats of format or -\n" + ) # TODO maybe also think about an option with ~constant length reduction (but at least one piece each segment) so that + # the long segments are not complete random averaged stuff + + parser.add_argument( + "--log-ids", type=str, help="for what ids to log, format: :arg1,:arg1:arg2,... without spaces, " \ + + "operator can be =(id) [=:id] or %(X, id) [%:1000:0], meaning exact id or ids of that modulo X" + ) + + parser.add_argument( + "--random-log-freq", type=float, help="how frequently (pbb) to log for randomly chosen ids" + ) + + # to have some data logged, need to specify for which IDs (--log-ids and/or --random-log-freq) and what to log (flags below) + + parser.add_argument( + "--segm-log-dir", type=str, help="where to log chosen segmentation images; also serves as 'do log' flag" + ) + + parser.add_argument( + "--repr-data-log-dir", type=str, help="where to log chosen array data (representation data, raw input images, and segment borders if segmentation); also serves as 'do log' flag" + ) + + + def __init__(self, args): super().__init__() self.args = args @@ -398,6 +435,65 @@ def __init__(self, args): self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim) + # options for choosing ids to log for + self.random_log_freq = args.random_log_freq if 'random_log_freq' in args else None + if 'log_ids' in args: + options = args.log_ids.split(",") + self.log_ids = [] + for opt in options: + details = opt.split(':') + if details[0] == "%": + # need to bind details like that, because otherwise details variable will be bound to some random thing + # that was used later and was also named details; + # imo one of the nastiest things in python, + # as scope in python is "until end of function" and not "until the end of the function or sth" + self.log_ids.append((lambda details: (lambda x: x % int(details[1]) == int(details[2])))(details)) + elif details[0] == '=': + self.log_ids.append((lambda details: (lambda x: x == int(details[1])))(details)) + else: + assert False + else: + self.log_ids = None + + # part for supported segmentation options + if 'segm' in args: + segm_opts = args.segm.split(":") + # this part needs to set stuff needed by 'segmentation' method + if segm_opts[0] == "hier": + self.segm = "var" + assert len(segm_opts) == 5 + self.hier_segm_merge_priority = segm_opts[1] + self.hier_segm_rounding_loss = segm_opts[2] if segm_opts[2] != "none" else None + shorten_opts = segm_opts[3].split("+") + self.hier_segm_shortening_policy = shorten_opts[0] + self.hier_segm_guess_orig = len(shorten_opts) > 1 and shorten_opts[1] == "guess_orig" + length_reduction_options = list(map(float, segm_opts[4].split("-"))) + if len(length_reduction_options) == 1: + self.hier_segm_strict_reduction = length_reduction_options[0] + self.hier_segm_reduction_range = None + elif len(length_reduction_options) == 2: + self.hier_segm_strict_reduction = None + assert length_reduction_options[0] <= length_reduction_options[1] + self.hier_segm_reduction_range = tuple(length_reduction_options) + else: + assert False + else: + assert False # for now only that supported + if 'segm_log_dir' in args: + self.segm_log_dir = args.segm_log_dir + else: + self.segm_log_dir = None + else: + self.segm = None + self.segm_log_dir = None + + if 'repr_data_log_dir' in args: + self.repr_data_log_dir = args.repr_data_log_dir + else: + self.repr_data_log_dir = None + + self.need_logging = self.segm_log_dir is not None or self.repr_data_log_dir is not None + def upgrade_state_dict_named(self, state_dict, name): super().upgrade_state_dict_named(state_dict, name) """Upgrade a (possibly old) state dict for new versions of fairseq.""" @@ -525,7 +621,7 @@ def compute_preds(self, x, y, negatives): return logits - def forward(self, source, padding_mask=None, mask=True, features_only=False): + def forward(self, source, padding_mask=None, mask=True, features_only=False, id=None, epoch=None): # padding_mask = None # JCh: padding_mask prob need to be True where the data is padded. mask=True => data invalid if self.feature_grad_mult > 0: @@ -541,17 +637,55 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): features = features.transpose(1, 2) features = self.layer_norm(features) - unmasked_features = features.clone() + # unmasked_features = features.clone() needed to move after segmentation if padding_mask is not None: assert padding_mask.size(1) == 1 padding_mask = padding_mask.squeeze(1) + scale_float = float(padding_mask.size(1)) / features.size(1) scale = padding_mask.size(1) // features.size(1) extra = padding_mask.size(1) % features.size(1) # should be 0 since 1st CNN reduces number of features [scale] times (due to the architecture choice) assert extra == 0 padding_mask = padding_mask[:, ::scale] assert np.all(padding_mask.shape == features.shape[:-1]) - + else: + scale_float = float(source.size(1)) / features.size(1) + + # TODO maybe move logging segments images to a script also and here just log borders array? or option both here and as a script potentially + if self.need_logging: + for i in range(source.shape[0]): + if self.check_if_log_for_id(id=id[i].item()): + if self.repr_data_log_dir: + self.log_repr_nonsegmentation_data(source[i], features[i], id=id[i].item() if id is not None else None, epoch=epoch) + # [!] logging here, before projection as this is what representation segmentation uses + # - TODO would otherwise need to change unmasked_features = features.clone() to be after projection instead of before + + if self.segm: + if self.hier_segm_guess_orig: + unmasked_features = features.clone() # if guessing original features, get them before averaging + features, padding_mask, segment_borders, rounding_loss = self.segmentation(features, padding_mask, 5) + # [!] minSegmsPerLine needs to be at least a few so that part with masking with at least 2 masks works correctly + + if self.need_logging: + for i in range(source.shape[0]): + if self.check_if_log_for_id(id=id[i].item()): + if self.segm_log_dir: + assert self.segm + # changed segment lines to only log begins (1 is there now for every segment, -1 if length > 1) + # as can just mult by scale for begins, for ends would need to also add scale - 1 + self.log_named_segmented_image(source[i], [int(round(j*scale_float)) for j, k in enumerate(segment_borders[i]) if k.item() == 1], id=id[i].item() if id is not None else None, epoch=epoch) + if self.repr_data_log_dir and self.segm: + self.log_repr_segmentation_data(segment_borders[i], id=id[i].item() if id is not None else None, epoch=epoch) + # [!] logging here, before projection as this is what representation segmentation uses + # - TODO would otherwise need to change unmasked_features = features.clone() to be after projection instead of before + + if not self.segm or not self.hier_segm_guess_orig: + unmasked_features = features.clone() + + assert(unmasked_features is not None) + + # doing it here as needed to clone features after segmentation and clone was before post_extract_proj + # - [!] TODO maybe check if this (cloning before post_extract_proj) is intended, but perhaps not a very big difference (only linear projection) if self.post_extract_proj is not None: features = self.post_extract_proj(features) @@ -645,9 +779,70 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): result["code_perplexity"] = code_ppl result["num_vars"] = num_vars result["temp"] = curr_temp + if self.segm and self.hier_segm_rounding_loss is not None: + result["rounding_loss"] = rounding_loss return result + def segmentation(self, features, padding_mask, minSegmsPerLine): + assert self.segm == 'var' # for now only that supported, to be extended + non_padded = padding_mask.numel() - padding_mask.sum().item() + if self.hier_segm_strict_reduction is not None: + base_len_sum = int(round(non_padded / self.hier_segm_strict_reduction)) + return HierarchicalSegmentationLayer.apply(features, padding_mask, base_len_sum, None, minSegmsPerLine, self.hier_segm_merge_priority, self.hier_segm_shortening_policy, self.hier_segm_rounding_loss) + else: + min_reduction, max_reduction = self.hier_segm_reduction_range + min_segm = base_len_sum = int(round(non_padded / max_reduction)) #max(features.shape[0], int(round(0.85*base_len_sum))) + max_segm = base_len_sum = int(round(non_padded / min_reduction)) #min(non_padded, int(round(1.15*base_len_sum))) + return HierarchicalSegmentationLayer.apply(features, padding_mask, None, (min_segm, max_segm), minSegmsPerLine, self.hier_segm_merge_priority, self.hier_segm_shortening_policy, self.hier_segm_rounding_loss) + + def log_segmented_image(self, img, borders, name=None, convert_numbers_from_01=True): + converted_grayscale_img = img*255. if convert_numbers_from_01 else img + if torch.is_tensor(converted_grayscale_img): + converted_grayscale_img = converted_grayscale_img.detach().cpu() + img = Image.fromarray(np.array(converted_grayscale_img, dtype=np.int32)).convert('RGB') + draw = ImageDraw.Draw(img) + for border in borders: + #if borders[i] != 0: + #print("!", source[0].shape, i*scale_float) + draw.line([(border, 0), (border, 31)], fill='red', width=2) + save_name = name if name is not None else "" + img.save(self.segm_log_dir + "/" + save_name + ".png") + + def log_named_segmented_image(self, img, borders, id=None, epoch=None): + name = "segm_id_" + str(id) + "_epoch_" + str(epoch) if id is not None else None # will have names with id, possibly overwriting each epoch, otherwise random ids + self.log_segmented_image(img, borders, name=name, convert_numbers_from_01=True) + + def log_repr_nonsegmentation_data(self, img, features, id=None, epoch=None): + if torch.is_tensor(img): + img = img.detach().cpu() + if torch.is_tensor(features): + features = features.detach().cpu() + img_np = np.array(img) + features_np = np.array(features) + img_name = "input_id_" + str(id) + "_epoch_" + str(epoch) if id is not None else None # will have names with id, possibly overwriting each epoch, otherwise random ids + features_name = "features_id_" + str(id) + "_epoch_" + str(epoch) if id is not None else None # will have names with id, possibly overwriting each epoch, otherwise random ids + np.save(self.repr_data_log_dir + "/" + img_name, img_np) + np.save(self.repr_data_log_dir + "/" + features_name, features_np) + + def log_repr_segmentation_data(self, borders, id=None, epoch=None): + if torch.is_tensor(borders): + borders = borders.detach().cpu() + borders_np = np.array(borders) + borders_name = "segmentborders_id_" + str(id) + "_epoch_" + str(epoch) if id is not None else None # will have names with id, possibly overwriting each epoch, otherwise random ids + np.save(self.repr_data_log_dir + "/" + borders_name, borders_np) + + def check_if_log_for_id(self, id=None): + if self.random_log_freq is not None: + if random.random() < self.random_log_freq: + return True + if self.log_ids is not None: + assert id is not None # need to use pass-metadata arg in criterion (if wav2vec, if other need to add this option) + for log_rule in self.log_ids: + if log_rule(id): # check if fits + return True + return False + def quantize(self, x): assert self.quantizer is not None x = self.feature_extractor(x) @@ -681,6 +876,9 @@ def get_extra_losses(self, net_output): if "features_pen" in net_output: pen.append(net_output["features_pen"]) + if self.segm and self.hier_segm_rounding_loss is not None: + pen.append(net_output["rounding_loss"]) + return pen def remove_pretraining_modules(self): diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py index e2326ac6e3..fbcfb4ead3 100644 --- a/fairseq/modules/__init__.py +++ b/fairseq/modules/__init__.py @@ -35,6 +35,7 @@ from .unfold import unfold1d from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer from .vggblock import VGGBlock +from .segmentation.hierarchical_segmentation import hierarchicalSegmentation, HierarchicalSegmentationLayer __all__ = [ "AdaptiveInput", @@ -72,5 +73,7 @@ "TransformerEncoderLayer", "TransposeLast", "VGGBlock", + "hierarchicalSegmentation", + "HierarchicalSegmentationLayer", "unfold1d", ] diff --git a/fairseq/modules/segmentation/__init__.py b/fairseq/modules/segmentation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/fairseq/modules/segmentation/hierarchical_segmentation.py b/fairseq/modules/segmentation/hierarchical_segmentation.py new file mode 100644 index 0000000000..57ed6bbd8e --- /dev/null +++ b/fairseq/modules/segmentation/hierarchical_segmentation.py @@ -0,0 +1,329 @@ + +import torch +import numpy as np +from .segment_dict import * +from heapq import * +from torch.autograd import Function, Variable + +def variance(linearSum, squaresSum, size): + return np.sum((squaresSum / size) - np.square(linearSum / size)) # sum of "variance mse vector" + +def se(linearSum, squaresSum, size): # square error + return np.sum(squaresSum - np.square(linearSum) / size) # sum of "se vector" + +def varianceDiff(linearSum1, squaresSum1, size1, linearSum2, squaresSum2, size2): + return variance(linearSum1 + linearSum2, squaresSum1 + squaresSum2, size1 + size2) - variance(linearSum1, squaresSum1, size1) - variance(linearSum2, squaresSum2, size2) + +def seDiff(linearSum1, squaresSum1, size1, linearSum2, squaresSum2, size2): + return se(linearSum1 + linearSum2, squaresSum1 + squaresSum2, size1 + size2) - se(linearSum1, squaresSum1, size1) - se(linearSum2, squaresSum2, size2) + +def cosDist(linearSum1, squaresSum1, size1, linearSum2, squaresSum2, size2): # cosine distance + unscaledSim = np.dot(linearSum1, linearSum2) / (np.sqrt(np.dot(linearSum1, linearSum1)) * np.sqrt(np.dot(linearSum2, linearSum2))) + unscaledAsDist = -unscaledSim + 1. # change from similarity to distance; we mainly care about order for priority queue, for that any mapping reversing order is ok (low similarity = high distance) + # ^ here we have a change from [-1, 1] to [0, 2]; standard "cosine distance" + return unscaledAsDist * (size1 + size2) + # scaling so that big nonsense averaged almost-random segments don't appear as similar (randomnoise1 ~= randomnoise2) + # this is where changing form similarity to distance mapping can make a difference, but linear one seems ok + # this scaling is similar to the sum of distances of all elements to the average of the another segment and vice versa (can use sums instead of averages for cosine sim; + # but that's perhaps not exactly this sum as cosine_similarity ( (sum_i a_i) , x ) is not the same as (sum_i cosine_similarity ( a_i , x )) ) + # but the other one would be more expensive to compute + +def linRoundingLoss(mean, originals): + return torch.abs(originals - mean).sum() + +def varRoundingLoss(mean, originals): + return torch.mean(torch.square(originals - mean), dim=0).sum() + +def seRoundingLoss(mean, originals): + return torch.square(originals - mean).sum() + +def cosRoundingLoss(mean, originals): + unscaledSim = torch.matmul(mean, originals) / (torch.sqrt(torch.dot(mean, mean)) * torch.sqrt(torch.matmul(originals, originals))) + unscaledAsDist = -unscaledSim + 1. + return unscaledAsDist.sum() + +# [!] lines has to be a numpy array, np.sum() crashes if done on tensor +def hierarchicalSegmentation(lines, padMask=None, k=None, minSegmsPerLine=None, mergePriority="mse"): # k is sum of number of segments for all lines + + if mergePriority == "se": # var not divided by size, square error + costFun = seDiff #lambda linearSum1, squaresSum1, size1, linearSum2, squaresSum2, size2: seDiff(linearSum1, squaresSum1, size1, linearSum2, squaresSum2, size2) + elif mergePriority == "var": # var is mse + costFun = varianceDiff #lambda linearSum1, squaresSum1, size1, linearSum2, squaresSum2, size2: varianceDiff(linearSum1, squaresSum1, size1, linearSum2, squaresSum2, size2) + elif mergePriority == "cos": + costFun = cosDist #lambda linearSum1, squaresSum1, size1, linearSum2, squaresSum2, size2: cos(linearSum1, squaresSum1, size1, linearSum2, squaresSum2, size2) + else: + assert False + + # TODO check if tensor parts correctly taken etc. [!] + segmentsDict = SegmentDict(lines, padMask=padMask, minSegmsPerLine=minSegmsPerLine) + + # maybe will need to change this to arrays or so instead of dicts for efficiency + + # q for ranges to merge + q = [] + + # every pair added only one time; after merges will need to add both to right and to left + for segm in segmentsDict.getSegments(): + segmRight = segmentsDict.getSegmentRight(segm) + if segmRight is not None: + linSum1, sqSum1 = segmentsDict.getSegmentSums(segm) + linSum2, sqSum2 = segmentsDict.getSegmentSums(segmRight) + line1, left1, right1 = segm + line2, left2, right2 = segmRight + #oldVar1 = costFun(linSum1, sqSum1, right1 - left1 + 1) + #oldVar2 = costFun(linSum2, sqSum2, right2 - left2 + 1) + #mergedVariance = costFun(linSum1 + linSum2, sqSum1 + sqSum2, right2 - left1 + 1) + size1 = right1 - left1 + 1 + size2 = right2 - left2 + 1 + costDiff = costFun(linSum1, sqSum1, size1, linSum2, sqSum2, size2) + heappush(q, (costDiff, segm, segmRight)) + + varChanges = [] + merges = [] + + while len(q) and (k is None or segmentsDict.numSegments() > k): # will stop merging before k reached if minSegmsPerLine reached + + varChange, left, right = heappop(q) + merged = segmentsDict.mergeSegments(left, right) # checks if merge is valid + + if merged is None: # old merge possibility, now impossible (or minSegmsPerLine reached for this line) + continue + + varChanges.append(varChange) + merges.append((left, right)) + + toLeft = segmentsDict.getSegmentLeft(merged) + toRight = segmentsDict.getSegmentRight(merged) + linSumMerged, sqSumMerged = segmentsDict.getSegmentSums(merged) + lineMerged, leftMerged, rightMerged = merged + sizeMerged = rightMerged - leftMerged + 1 + #varMerged = costFun(linSumMerged, sqSumMerged, rightMerged - leftMerged + 1) + + if toLeft is not None: + linSum2, sqSum2 = segmentsDict.getSegmentSums(toLeft) + line2, left2, right2 = toLeft + size2 = right2 - left2 + 1 + #oldVar2 = costFun(linSum2, sqSum2, right2 - left2 + 1) + #mergedVariance = costFun(linSumMerged + linSum2, sqSumMerged + sqSum2, rightMerged - left2 + 1) + costDiff = costFun(linSumMerged, sqSumMerged, sizeMerged, linSum2, sqSum2, size2) + heappush(q, (costDiff, toLeft, merged)) + + if toRight is not None: + linSum2, sqSum2 = segmentsDict.getSegmentSums(toRight) + line2, left2, right2 = toRight + size2 = right2 - left2 + 1 + #oldVar2 = costFun(linSum2, sqSum2, right2 - left2 + 1) + #mergedVariance = costFun(linSumMerged + linSum2, sqSumMerged + sqSum2, right2 - leftMerged + 1) + costDiff = costFun(linSumMerged, sqSumMerged, sizeMerged, linSum2, sqSum2, size2) + heappush(q, (costDiff, merged, toRight)) + + return varChanges, merges, segmentsDict + +class HierarchicalSegmentationLayer(Function): + + @staticmethod + def flatten(x): + s = x.shape() + if len(s) < 3: + return x + if len(s) == 3: + return x.view(-1, s[2]) + assert False + + # perhaps that ^ is not needed, and restore_shapes also + + @staticmethod + def forward(ctx, inputGPU, padMask=None, k=None, allowKsumRange=None, minSegmsPerLine=None, mergePriority="se", shorteningPolicy="orig_len", roundingLossType=None): + # k for strict num of segments (SUM FOR ALL LINES), allowKsumRange for range OF SUM OF SEGMENTS IN ALL LINES and choosing 'best' split point + # min and max number of merges adjusted to what is possible - e.g. because of minSegmsPerLine + + assert k is None or allowKsumRange is None # mutually exclusive options + assert shorteningPolicy in ("shorten", "orig_len") # orig_len+guess_orig is only at the higher level + assert roundingLossType in ("se", "var", "lin", "cos", None) + if roundingLossType == "se": + roundingLossFun = seRoundingLoss + elif roundingLossType == "var": + roundingLossFun = varRoundingLoss + elif roundingLossType == "lin": + roundingLossFun = linRoundingLoss + elif roundingLossType == "cos": + roundingLossFun = cosRoundingLoss + else: + assert False + + # TODO if input only 2-dim, add another dimension possibly (W x H -> 1 x W x H, consistent with B x W x H - later assuming that in some places) + + inputDevice = inputGPU.device + padMaskInputDevice = padMask.device if padMask is not None else False + + # tensor to CPU (don't really need copy, will just need to put tensors in segmentsDict) + input = inputGPU.detach().to('cpu').numpy() + # https://discuss.pytorch.org/t/cant-convert-cuda-tensor-to-numpy-use-tensor-cpu-to-copy-the-tensor-to-host-memory-first/38301 , + # https://discuss.pytorch.org/t/what-is-the-cpu-in-pytorch/15007/3 + + varChanges, merges, segmentsDict = hierarchicalSegmentation(input, padMask=padMask, k=k, minSegmsPerLine=minSegmsPerLine, mergePriority=mergePriority) # won't modify input + #print("MERGES0: ", merges) + if allowKsumRange: # full merge done above, k=None, so each line now has minSegmsPerLine, but can also just get it from SegmDict - cleaner + begin, end = allowKsumRange + assert begin <= end + # [!] min and max number of merges adjusted to what is possible - e.g. because of minSegmsPerLine + beginIdx = max(0, min(len(varChanges) - 1, (segmentsDict.numSegments() + (len(varChanges) - 1) - end))) # max allowed num of segments, smallest num of merges; input.shape[0] is num of segments if all merges done + endIdx = max(0, min(len(varChanges) - 1, (segmentsDict.numSegments() + (len(varChanges) - 1) - begin))) # min allowed num of segments, biggest num of merges; input.shape[0] is num of segments if all merges done + #print("::::::::::", beginIdx, endIdx) + prefSums = [] + s = 0. + for chng in varChanges: + s += chng + prefSums.append(s) + best = -1 + where = -1 + #print("PREFSUMS: ", prefSums) + for i in range(beginIdx, min(endIdx+1, len(varChanges))): + sufSum = s - prefSums[i] # sum after this index + prefSum = prefSums[i] if prefSums[i] > 0. else .0000001 # don't div by 0 + # v the bigger the better split point; suffix div by prefix averages of variance change + here = (sufSum / (len(varChanges)-i)) / (prefSum / (i+1.)) + #print("!", i, ":", prefSum ,sufSum, here) + + if here > best: + best = here + where = i + if where == -1: + print("WARNING: problems choosing best num segments") + where = int((beginIdx + endIdx) // 2) + varChanges = varChanges[:where+1] # this one is not really needed + merges = merges[:where+1] + + finalSegments, segmentNumsInLines = SegmentDict.getFinalSegments(merges, input.shape[:2], padMask=padMask) + #print("MERGES: ", merges) + #print("FINAL SEGMENTS: ", finalSegments) + + maxSegments = max(segmentNumsInLines) + + if shorteningPolicy == "shorten": + segmented = np.full((input.shape[0], maxSegments, input.shape[2]), 0.) #torch.tensor(size=(input.shape[0], maxSegments, input.shape[2])).fill_(0.) + paddingMaskOut = np.full((input.shape[0], maxSegments), False) #torch.BoolTensor(size=(input.shape[0], maxSegments)).fill_(False) + for i, n in enumerate(segmentNumsInLines): + paddingMaskOut[i][n:] = True + resPadMask = torch.BoolTensor(paddingMaskOut).to(padMaskInputDevice) + else: + segmented = np.full(input.shape, 0.) + resPadMask = padMask + # can perhaps return a tensor with 1 at the beginning of the segments, -1 at the end, 0s elsewhere + segmentBorders = np.zeros((input.shape[0], input.shape[1]), dtype=np.int8) + roundingLoss = torch.tensor(0, dtype=torch.float32).requires_grad_(True).to(inputDevice) # TODO dtype (?) + for line, idxInLine in finalSegments.keys(): + line, begin, end = finalSegments[(line, idxInLine)] + if shorteningPolicy == "shorten": + segmented[line][idxInLine] = np.mean(input[line][begin:(end+1)], axis=0) #torch.mean(input[line][begin:(end+1)]) + else: + segmented[line][begin:(end+1)] = np.mean(input[line][begin:(end+1)], axis=0) + roundingLoss += roundingLossFun(torch.mean(inputGPU[line][begin:(end+1)], dim=0), inputGPU[line][begin:(end+1)]) + segmentBorders[line][end] = -1 + segmentBorders[line][begin] = 1 # [!] can be e.g. [...0, 0, 1, 1, ...] with segment of length 1 + # - marking begins when length 1 as * scaling doesn't need + (scale-1) there if logging only begins + + resOutput = torch.tensor(segmented, dtype=inputGPU.dtype).to(inputDevice) #if wasInputOnGPU else torch.tensor(segmented) #.requires_grad_(True) + # resPadMask created above, as for some reason torch.BoolTensor(paddingMaskOut).to(padMaskInputDevice) thrown an error if paddingMaskOut was a tensor on a correct device + segmentBorders = torch.IntTensor(segmentBorders).to(inputDevice) + + #print("********************", dir(ctx)) + #[not really needed] ctx.save_for_backward(padMask, resPadMask) + # save_for_backward is only for tensors / variables / stuff + if shorteningPolicy == "shorten": + ctx.shortened = True + else: + ctx.shortened = False + ctx.finalSegments = finalSegments + ctx.segmentNumsInLines = segmentNumsInLines + ctx.inputShape = input.shape + ctx.mark_non_differentiable(resPadMask) # can only pass torch variables here and only that makes sense + + #print("FINAL SEGMENTS: ", finalSegments, segmentNumsInLines) + + # with rounding loss None, will just return 0 + return resOutput, resPadMask, segmentBorders, roundingLoss #, finalSegments, segmentNumsInLines can only return torch variables... TODO maybe check how to fetch this info, but not sure if needed + + @staticmethod + def backward(ctx, dxThrough, outPadMask=None, segmentBorders=None, roundingLoss=None): #, finalSegments=None, segmentNumsInLines=None): + + dxThroughDevice = dxThrough.device + + #[not really needed] paddingMask, paddingMaskOut = ctx.saved_tensors + dx = torch.empty(size=ctx.inputShape, dtype=dxThrough.dtype).fill_(0.).to('cpu') + + wasShortened = ctx.shortened + + for line, idxInLine in ctx.finalSegments.keys(): + line, begin, end = ctx.finalSegments[(line, idxInLine)] + if wasShortened: + dx[line][begin:(end+1)] = dxThrough[line][idxInLine] / (end - begin + 1) + else: + dx[line][begin:(end+1)] = (dxThrough[line][begin:(end+1)].sum(dim=0)) / (end - begin + 1) + + dx = dx.to(dxThroughDevice) + + return dx, None, None, None, None, None, None, None + + +if __name__ == '__main__': + # import ptvsd + # ptvsd.enable_attach(('0.0.0.0', 7309)) + # print("Attach debugger now") + # ptvsd.wait_for_attach() + + # run from .. with python -m segmentation.hierarchical_variance_segmentation + + tensor = torch.tensor([[[1,2],[1,2],[3,4],[3,4],[3,4],[8,9],[8,9]], [[1,2],[1,2],[3,4],[3,4],[3,4],[8,9],[8,9]]], dtype=torch.float64).requires_grad_(True) + print(tensor[0][1]) + print(hierarchicalSegmentation(tensor.detach().numpy(), padMask=None, k=4, minSegmsPerLine=None, mergePriority="se")) # pre-last merge in each line (merging (0,1) and (2,4)) should be 1.92 if summing 'variance vectors' + print(hierarchicalSegmentation(tensor.detach().numpy(), padMask=None, k=2, minSegmsPerLine=None, mergePriority="var")) # pre-last merge in each line (merging (0,1) and (2,4)) should be 1.92 if summing 'variance vectors' + + print("-------------------------- torch ---------------------------") + # (tensor, padMask, k, kSumRange) + resOutput, resPadMask, borders, roundingLoss = HierarchicalSegmentationLayer.apply(tensor, torch.tensor([[True, False, False, False, False, False, False], [False, False, False, False, False, False, True]]), None, (2,5), None, "var", "shorten", None) #(2, 5)) # can;t have keyword args for torch Functions... + print(resOutput) + print(resPadMask) + print(borders) + #print(finalSegments) + #print(segmentNumsInLines) + #loss = Variable(resOutput, requires_grad=True) + resOutput.sum().backward() # .backward() needs loss to be a number (tensor of size (1,)) + print(tensor.grad) + + print("-------------------------- torch2 ---------------------------") + # (tensor, padMask, k, kSumRange) + tensor.grad.data.zero_() + resOutput, resPadMask, borders, roundingLoss = HierarchicalSegmentationLayer.apply(tensor, torch.tensor([[True, False, False, False, False, False, False], [False, False, False, False, False, False, True]]), 3, None, None, "se", "shorten", None) #(2, 5)) # can;t have keyword args for torch Functions... + print(resOutput) + print(resPadMask) + print(borders) + #print(finalSegments) + #print(segmentNumsInLines) + #loss = Variable(resOutput, requires_grad=True) + resOutput.sum().backward() # .backward() needs loss to be a number (tensor of size (1,)) + print(tensor.grad) + + print("-------------------------- torch3 ---------------------------") + # (tensor, padMask, k, kSumRange) + tensor.grad.data.zero_() + resOutput, resPadMask, borders, roundingLoss = HierarchicalSegmentationLayer.apply(tensor, torch.tensor([[True, False, False, False, False, False, False], [False, False, False, False, False, False, True]]), 3, None, 2, "se", "shorten", None) #(2, 5)) # can;t have keyword args for torch Functions... + print(resOutput) + print(resPadMask) + print(borders) + # [!] here will return 4 segments instead of specified 3, because of specified minSegmsPerLine + + resOutput.sum().backward() # .backward() needs loss to be a number (tensor of size (1,)) + print(tensor.grad) + + print("-------------------------- torch4 ---------------------------") + # (tensor, padMask, k, kSumRange) + tensor.grad.data.zero_() + resOutput, resPadMask, borders, roundingLoss = HierarchicalSegmentationLayer.apply(tensor, torch.tensor([[True, False, False, False, False, False, False], [False, False, False, False, False, False, True]]), 3, None, 2, "se", "orig_len", None) #(2, 5)) # can;t have keyword args for torch Functions... + print(resOutput) + print(resPadMask) + print(borders) + # [!] here will return 4 segments instead of specified 3, because of specified minSegmsPerLine + + resOutput.sum().backward() # .backward() needs loss to be a number (tensor of size (1,)) + print(tensor.grad) \ No newline at end of file diff --git a/fairseq/modules/segmentation/segment_dict.py b/fairseq/modules/segmentation/segment_dict.py new file mode 100644 index 0000000000..2da056b15d --- /dev/null +++ b/fairseq/modules/segmentation/segment_dict.py @@ -0,0 +1,138 @@ + +import numpy as np + +class SegmentDict: + + def __init__(self, lines, padMask=None, minSegmsPerLine=None): # lines assumed to be of shape [#lines x line_len * k] + # (line#, place in line): (begin in line, end in line, sum(x), sum(x^2)) ; sums are possibly vectors + self._dct = {(i, j): (j, j, lines[i][j], np.square(lines[i][j])) for i in range(len(lines)) for j in range(len(lines[i])) if padMask is None or not padMask[i][j]} + self._size = len(self._dct) # sometimes 1 (now all segments have 1 entry), sometimes later 2 entries per segment - better keep a counter + + self._line_segms = [0 for i in range(lines.shape[0])] + for line, _ in self._dct: + self._line_segms[line] += 1 + + self.minSegmsPerLine = minSegmsPerLine + + # there is a 'segment' implicit format (tuple) used: (line#, leftIndex(begin), rightIndex(end)) + + def numSegments(self): + return self._size + + def segmentInDict(self, segment): + line, leftIdx, rightIdx = segment + if (line, leftIdx) not in self._dct: + return False + leftIdxFromDict, rightIdxFromDict, _, _ = self._dct[(line, leftIdx)] + return leftIdx == leftIdxFromDict and rightIdx == rightIdxFromDict + # (line, leftIdx) in dct can be for right-range leftIdx with different leftIdxFromDict for merged segment + + def removeSegment(self, segment): + line, leftIdx, rightIdx = segment + wasThere = False + if (line, leftIdx) in self._dct: + del self._dct[(line, leftIdx)] + wasThere = True + if (line, rightIdx) in self._dct: + del self._dct[(line, rightIdx)] + wasThere = True + if wasThere: + self._size -= 1 + self._line_segms[line] -= 1 + + def mergeSegments(self, segment1, segment2): + line1, left1, right1 = segment1 + line2, left2, right2 = segment2 + if not self.segmentInDict(segment1) or not self.segmentInDict(segment2) \ + or line1 != line2 or right1 + 1 != left2 \ + or (self.minSegmsPerLine and self._line_segms[line1] <= self.minSegmsPerLine): + # not subsequent or too few segments in line + return None + linearSum1, squaresSum1 = self.getSegmentSums(segment1) + linearSum2, squaresSum2 = self.getSegmentSums(segment2) + # remove old segments; will update _size + self.removeSegment(segment1) + self.removeSegment(segment2) + #assert (line1, left1) not in self._dct + #assert (line1, right1) not in self._dct + #assert (line1, left2) not in self._dct + #assert (line1, right2) not in self._dct + # add a new merged one; need to update _size by hand + self._dct[(line1, left1)] = (left1, right2, linearSum1 + linearSum2, squaresSum1 + squaresSum2) + self._dct[(line1, right2)] = (left1, right2, linearSum1 + linearSum2, squaresSum1 + squaresSum2) + #print(segment1, segment2, "->", (line1, left1, right2)) + self._size += 1 + self._line_segms[line1] += 1 + return (line1, left1, right2) + + def getSegments(self): + res = [] + for (line, leftIdx) in self._dct.keys(): + begin, end, _, _ = self._dct[(line, leftIdx)] + res.append((line, begin, end)) + return res + + def getSegmentLeft(self, segment): + if not self.segmentInDict(segment): + return None + line, left, right = segment + if (line, left - 1) not in self._dct: + return None + segmLeft, segmRight, _, _ = self._dct[(line, left - 1)] + #print(left, right, "!", segmLeft, segmRight, left - 1) + #assert segmRight == left - 1 + #assert self._dct[(line, segmLeft)][0] == segmLeft and self._dct[(line, segmLeft)][1] == segmRight + return (line, segmLeft, segmRight) + + def getSegmentRight(self, segment): + if not self.segmentInDict(segment): + return None + line, left, right = segment + if (line, right + 1) not in self._dct: + return None + segmLeft, segmRight, _, _ = self._dct[(line, right + 1)] + #assert segmLeft == right + 1 + #assert self._dct[(line, segmRight)][0] == segmLeft and self._dct[(line, segmRight)][1] == segmRight + return (line, segmLeft, segmRight) + + def getSegmentSums(self, segment): + if not self.segmentInDict(segment): + return None + line, left, _ = segment + _, _, linearSum, squaresSum = self._dct[(line, left)] + return (linearSum, squaresSum) + + @staticmethod + def getFinalSegments(merges, shape, padMask=None): # shape needs to be B x W, without height! + + visited = np.zeros(shape, dtype=np.int32) + finalSegments = [] + for i in range(len(merges)-1,-1,-1): + leftSegm, rightSegm = merges[i] + line, beginLeft, endLeft = leftSegm + if visited[line][beginLeft] != 0: + continue # merge already seen + _, beginRight, endRight = rightSegm + finalSegments.append((line, beginLeft, endRight)) + visited[line][beginLeft:(endRight+1)] = 1 + + # add length-1 segments that are there not padded but were not a part of any merge + for i in range(visited.shape[0]): + for j in range(visited.shape[1]): + if not visited[i][j] and (padMask is None or not padMask[i][j]): + finalSegments.append((i, j, j)) + + lineCounter = 0 + prevLine = 0 # don't append useless 0 at the beginning + res = {} # {(line, #ofSegmentInLine): (line, beginIdx, endIdx)} + segmentsInLines = [] # numbers of segments in lines + for line, begin, end in sorted(finalSegments): + if line != prevLine: + prevLine = line + segmentsInLines.append(lineCounter) + lineCounter = 0 + res[(line, lineCounter)] = (line, begin, end) + lineCounter += 1 + segmentsInLines.append(lineCounter) + + return res, segmentsInLines # there will be always at least 1 segment in a line \ No newline at end of file diff --git a/fairseq/tasks/scribblelens.py b/fairseq/tasks/scribblelens.py index cd910e2d3b..00b085fc92 100644 --- a/fairseq/tasks/scribblelens.py +++ b/fairseq/tasks/scribblelens.py @@ -98,7 +98,9 @@ def load_dataset(self, split, **kwargs): split (str): name of the split (e.g., train, valid, test) """ - vocab_path = self.args.vocab_path if self.args.vocab_path is not None else self.args.data + '/tasman.alphabet.plus.space.mode5.json' + vocab_path = self.args.vocab_path if self.args.vocab_path is not None else '' + # [now file in default location not used, needs to be specified, otherwise trying to construct vocab from scratch] + # self.args.data + '/tasman.alphabet.plus.space.mode5.json' if not self.args.labels: self.datasets[split] = FileHandwritingDataset( diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 19ca213d55..685531f77b 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -935,6 +935,10 @@ def _prepare_sample(self, sample): sample["target"], device=self.last_device ) else: + # v needed if non-tensor stuff in sample (e.g. metadata), but kept tensors for safety + # for key in sample: + # if torch.is_tensor(key): + # sample[key] = utils.move_to_cuda(sample[key]) sample = utils.move_to_cuda(sample) def apply_half(t): diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index e1af605348..cda45607fd 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -213,6 +213,8 @@ def train( should_stop = False num_updates = trainer.get_num_updates() for i, samples in enumerate(progress): + for sample in samples: + sample["epoch"] = torch.tensor(epoch_itr.epoch, dtype=torch.int16) with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function( "train_step-%d" % i ): @@ -352,6 +354,7 @@ def validate( # don't pollute other aggregators (e.g., train meters) with metrics.aggregate(new_root=True) as agg: for sample in progress: + sample["epoch"] = torch.tensor(epoch_itr.epoch, dtype=torch.int16) trainer.valid_step(sample) # log validation stats diff --git a/uwr_related/experiments/pp/gen_sim_imgs.py b/uwr_related/experiments/pp/gen_sim_imgs.py new file mode 100644 index 0000000000..a413e1e71c --- /dev/null +++ b/uwr_related/experiments/pp/gen_sim_imgs.py @@ -0,0 +1,106 @@ + +import sys +import os +import numpy as np +from PIL import Image, ImageDraw +#import matplotlib.pyplot as plt + +assert len(sys.argv) == 2 or len(sys.argv) == 3 # pass directory with data as an arg and also (default 50) for how many pixels there should be a helper grid line +#print(sys.argv[0], sys.argv[1]) + +# TODO a mode with grid on segment borders +gridHz = int(sys.argv[2]) if len(sys.argv) == 3 else 50 + +# this should map for whole array, can e.g. use numpy etc. then +def mapSimFromDist(numArr, minVal=0., maxVal=1.): + # this one only makes sense with >= 0 values + # maps to one colour + # TODO logarithmic option?? + #minRev = 1./maxVal + #maxRev = 1./maxVal + #mapped = (1./num - minRev) / maxRev + if maxVal == 0: + maxVal = 1 + return (float(maxVal) - (numArr - minVal)) / (maxVal - minVal) + +dct = {} +for f in os.scandir(sys.argv[1]): + els = f.name.split("_") + t = els[0] + rest = "_".join(els[1:]) + dct[(t,rest)] = f.name + +print(dct) + +print("===================", mapSimFromDist(0.), mapSimFromDist(1.)) + +for t, rest in dct: + if t != "input": + continue + if ("features", rest) not in dct: + continue + inputFile = sys.argv[1] + "/" + t + "_" + rest + featuresFile = sys.argv[1] + "/" + "features" + "_" + rest + print("processing", inputFile, featuresFile) + inputArr = np.load(inputFile) + featuresArr = np.load(featuresFile).T + + # calculating array of distances between the representations; TODO add some params for different options, also similarity without distance in between + inputSq = np.square(inputArr).sum(axis=0) + featuresSq = np.square(featuresArr).sum(axis=0) + #print(inputArr.shape, inputSq.shape, featuresArr.shape, featuresArr.T.shape, featuresSq.shape, featuresSq[np.newaxis,:].T.shape) + distArr = featuresSq + featuresSq[np.newaxis,:].T - 2. * np.matmul(featuresArr.T, featuresArr) + #print(distArr.shape) + + bigArr = np.zeros((inputArr.shape[0] + 3 + inputArr.shape[1], inputArr.shape[0] + 3 + inputArr.shape[1], 3)) + #print(bigArr.shape, inputArr.shape) + + # adding input images on the top and on the left and blue lines to an image + bigArr[:inputArr.shape[0], (inputArr.shape[0]+3):(inputArr.shape[0]+3+inputArr.shape[1]), :] = inputArr[:,:,np.newaxis] + bigArr[(inputArr.shape[0]+3):(inputArr.shape[0]+3+inputArr.shape[1]), :inputArr.shape[0], :] = np.flip(inputArr, axis=0).T[:,:,np.newaxis] + bigArr[inputArr.shape[0]:(inputArr.shape[0] + 3), (inputArr.shape[0]):(inputArr.shape[0]+3+inputArr.shape[1]), 2] = 1. + bigArr[(inputArr.shape[0]):(inputArr.shape[0]+3+inputArr.shape[1]), inputArr.shape[0]:(inputArr.shape[0] +3), 2] = 1. + a1 = ((inputArr.shape[0], inputArr.shape[0] + 3), (inputArr.shape[0], inputArr.shape[0]+3+inputArr.shape[1])) + a2 = ((inputArr.shape[0], inputArr.shape[0]+3+inputArr.shape[1]), (inputArr.shape[0], inputArr.shape[0] +3)) + #print("!!!", distArr[:2,:2]) + #sim = plt.imshow(distArr, cmap='viridis', ) + scaleFloat = float(inputArr.shape[1]) / float(distArr.shape[0]) + #print("---->", scaleFloat, inputArr.shape[1], distArr.shape[0]) + + # creating similarity array from distance array; TODO as mentioned where creating distArr, add some params for different options, also without dist, like e.g. cosine sim + minVal = distArr.min() + maxVal = distArr.max() + simArr = mapSimFromDist(distArr, minVal, maxVal) + + # converting stuff to 0-255 (but not ints where not needed yet) and increasing size 2x + bigArr = bigArr * 255. # scaling here, as similarity scaled separately below + #print("!!!", bigArr.shape, np.ones((2,2)).shape) + bigArr = bigArr.repeat(2, axis=0).repeat(2, axis=1) + #print("!!!", bigArr.shape) + bigArr[2*(inputArr.shape[0] + 3):, 2*(inputArr.shape[0] + 3):, 0] = np.asarray(Image.fromarray(np.array(simArr*255., dtype=np.int8)).resize((inputArr.shape[1]*2, inputArr.shape[1]*2), resample=Image.NEAREST)) + + # from here stuff is 2x bigger (as grid lines were otherwise too big) + + # choosing helper grid positions to plot + if ("segmentborders", rest) in dct and len(sys.argv) == 2: # ONLY plot as segments if no grid density specified + bordersFile = featuresFile = sys.argv[1] + "/" + "segmentborders" + "_" + rest + bordersArr = np.load(bordersFile) + #print("--------", bordersArr) + gridBorders = [ 2*(inputArr.shape[0] + 3) + 2*int(round(j*scaleFloat)) for j, k in enumerate(bordersArr) if k == 1] + #print("[][][][]", gridBorders) + else: + gridBorders = range(2*(inputArr.shape[0]+3), 2*(inputArr.shape[0]+3+inputArr.shape[1]), gridHz*2) + + # plotting helper grid + for i in gridBorders: + # grid helper + bigArr[i, 2*(inputArr.shape[0]+3):, :] = 255 + bigArr[2*(inputArr.shape[0]+3):, i, :] = 255 + # for i in range(inputArr.shape[1]): + # for j in range(inputArr.shape[1]): + # bigArr[inputArr.shape[0] + 3 + i, inputArr.shape[0] + 3 + j][0] = mapSimFromDist(distArr[int(i / scaleFloat)][int(j / scaleFloat)], minVal, maxVal) + + # saving image from array + img = Image.fromarray(np.array(bigArr, dtype=np.int8), 'RGB') #.resize((bigArr.shape[0]*2, bigArr.shape[1]*2)) # PIL needs EXPLICIT int8, won't understand that int32 of values <256 is int8 + img.save(sys.argv[1] + "/" + "visualization" + "_" + rest.split(".")[0] + ".png") + #img.show() diff --git a/uwr_related/test_cmd_scribble.sh b/uwr_related/test_cmd_scribble.sh index 0fd001454a..316218e282 100755 --- a/uwr_related/test_cmd_scribble.sh +++ b/uwr_related/test_cmd_scribble.sh @@ -41,20 +41,22 @@ python train.py --distributed-world-size 1 --update-freq 2 \ /pio/scratch/1/i283340/MGR/NewSetup/DistSup/data `#path to Scribblelens data folder` \ --vocab-path ./fairseq/data/handwriting/tasman.alphabet.plus.space.mode5.json `#alphabet file` \ --save-dir ../try_sl1 --num-workers 0 \ - --task scribblelens --criterion wav2vec --arch wav2vec2_scribblelens \ + --task scribblelens --criterion wav2vec `#--pass-metadata` --arch wav2vec2_scribblelens \ --valid-subset test --pad-to-multiples-of 4 `#--max-sample-size 256` \ --log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --extractor-mode default \ --conv-feature-layers '[(64, (3, 3), (1, 2), (1, 1)), (128, (5, 5), (2, 2), (2, 2)), (256, (3,3), (1, 1), (1, 1)), (256, (3,3), (1, 2), (1, 1)), (512, (3,3), (1, 1), (1, 1)), (512, (3,3), (1, 2), (1, 1)), (512, (3,2), (2, 1), (1, 0))]' \ --final-dim 256 \ --latent-vars 320 --latent-groups 2 --latent-temp '(2,0.5,0.999995)' --infonce \ --optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay \ - --total-num-update 400000 --lr 0.0005 --warmup-updates 32000 \ + --total-num-update 40000 --lr 0.0005 --warmup-updates 3000 `#32000 is too much for scribblelens, more than twice as far as it collapses, same 400000 updates` \ --mask-length 10 --mask-prob 0.65 --mask-selection static --mask-other 0 \ --encoder-layerdrop 0.05 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 0.1 \ - --loss-weights '[0.1, 10]' --conv-pos 128 --conv-pos-groups 16 \ + --loss-weights '[0.1, 10]' `#'[0.1, 10, 3]'` --conv-pos 128 --conv-pos-groups 16 \ --num-negatives 100 --cross-sample-negatives 0 \ `#--max-sample-size 250000 --min-sample-size 32000` \ --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --max-tokens 10000 --max-update 400000 \ --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d \ --labels `#can be removed for no labels` \ + `#--segm-log-dir ../imgs3 --repr-data-log-dir ../repr3 --random-log-freq 0.0001 --log-ids =:715,%:1000:123` \ + `#--segm hier:se:none:shorten:2.5-3.5 # optional segmentation` \ --enable-padding # crashes without that, needs to make all lines same-size \ No newline at end of file