Showing 3 changed files with 976 additions and 1 deletion.
@@ -1 +1,7 @@
-# PaSST
+# PaSST: Efficient Training of Audio Transformers with Patchout
+
+
+This is the implementation for [Efficient Training of Audio Transformers with Patchout](https://arxiv.org/abs/2110.05069)
+
+
+
@@ -0,0 +1,261 @@
""" | ||
Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py | ||
""" | ||
import math | ||
import warnings | ||
from copy import deepcopy | ||
|
||
import torch | ||
from timm.models.helpers import load_pretrained | ||
from torch import nn | ||
|
||
|
||
def overlay_external_default_cfg(default_cfg, kwargs):
    """ Overlay 'external_default_cfg' in kwargs on top of default_cfg arg.
    """
    external_default_cfg = kwargs.pop('external_default_cfg', None)
    if external_default_cfg:
        default_cfg.pop('url', None)  # url should come from external cfg
        default_cfg.pop('hf_hub', None)  # hf hub id should come from external cfg
        default_cfg.update(external_default_cfg)


def filter_kwargs(kwargs, names):
    if not kwargs or not names:
        return
    for n in names:
        kwargs.pop(n, None)


def set_default_kwargs(kwargs, names, default_cfg):
    for n in names:
        # for legacy reasons, model __init__ args use img_size + in_chans as separate args while
        # default_cfg has one input_size=(C, H, W) entry
        if n == 'img_size':
            input_size = default_cfg.get('input_size', None)
            if input_size is not None:
                assert len(input_size) == 3
                kwargs.setdefault(n, input_size[-2:])
        elif n == 'in_chans':
            input_size = default_cfg.get('input_size', None)
            if input_size is not None:
                assert len(input_size) == 3
                kwargs.setdefault(n, input_size[0])
        else:
            default_val = default_cfg.get(n, None)
            if default_val is not None:
                kwargs.setdefault(n, default_cfg[n])


def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
    """ Update the default_cfg and kwargs before passing to model
    FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs
    could/should be replaced by an improved configuration mechanism
    Args:
        default_cfg: input default_cfg (updated in-place)
        kwargs: keyword args passed to model build fn (updated in-place)
        kwargs_filter: keyword arg keys that must be removed before model __init__
    """
    # Overlay default cfg values from `external_default_cfg` if it exists in kwargs
    overlay_external_default_cfg(default_cfg, kwargs)
    # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
    default_kwarg_names = ('num_classes', 'global_pool', 'in_chans')
    if default_cfg.get('fixed_input_size', False):
        # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
        default_kwarg_names += ('img_size',)
    set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg)
    # Filter keyword args for task specific model variants (some 'features only' models, etc.)
    filter_kwargs(kwargs, names=kwargs_filter)
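
# Example (illustrative sketch; the cfg values below are hypothetical): given a
# default_cfg with a fixed input size, the helpers above map input_size=(C, H, W)
# onto the separate in_chans / img_size __init__ args and strip filtered kwargs.
#
#   cfg = {'num_classes': 527, 'input_size': (1, 128, 998), 'fixed_input_size': True}
#   kwargs = {'features_only': True}
#   update_default_cfg_and_kwargs(cfg, kwargs, kwargs_filter=('features_only',))
#   # kwargs -> {'num_classes': 527, 'in_chans': 1, 'img_size': (128, 998)}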


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
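
# Example (illustrative sketch; shapes are arbitrary): DropPath zeroes the residual
# branch for a random subset of samples and rescales the survivors by 1/keep_prob,
# so the expected activation is unchanged. It is the identity in eval mode.
#
#   dp = DropPath(drop_prob=0.2).train()
#   y = dp(torch.randn(8, 196, 768))   # (batch, tokens, dim); ~20% of samples dropped
#   assert torch.equal(dp.eval()(torch.ones(2, 3)), torch.ones(2, 3))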


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    elif mode == 'fan_avg':
        denom = (fan_in + fan_out) / 2
    else:
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
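
# Example (illustrative sketch; the layer is hypothetical): these initializers work
# in-place on any weight tensor with a well-defined fan-in/fan-out, e.g. a conv
# used as a patch-embedding projection.
#
#   proj = nn.Conv2d(1, 768, kernel_size=16, stride=16)
#   lecun_normal_(proj.weight)            # fan_in-scaled truncated normal
#   trunc_normal_(proj.weight, std=.02)   # fixed-std truncated normal, ViT-style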


def build_model_with_cfg(
        model_cls,
        variant: str,
        pretrained: bool,
        default_cfg: dict,
        model_cfg=None,
        feature_cfg=None,
        pretrained_strict: bool = True,
        pretrained_filter_fn=None,
        pretrained_custom_load=False,
        kwargs_filter=None,
        **kwargs):
    """ Build model with specified default_cfg and optional model_cfg
    This helper fn aids in the construction of a model including:
      * handling default_cfg and associated pretrained weight loading
      * passing through optional model_cfg for models with config based arch spec
      * features_only model adaptation
      * pruning config / model adaptation
    Args:
        model_cls (nn.Module): model class
        variant (str): model variant name
        pretrained (bool): load pretrained weights
        default_cfg (dict): model's default pretrained/task config
        model_cfg (Optional[Dict]): model's architecture config
        feature_cfg (Optional[Dict]): feature extraction adapter config
        pretrained_strict (bool): load pretrained weights strictly
        pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
        pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights
        kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
        **kwargs: model args passed through to model __init__
    """
    pruned = kwargs.pop('pruned', False)  # popped so it does not reach model __init__
    features = False
    feature_cfg = feature_cfg or {}
    default_cfg = deepcopy(default_cfg) if default_cfg else {}
    update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter)
    default_cfg.setdefault('architecture', variant)

    # Setup for feature extraction wrapper done at end of this fn
    if kwargs.pop('features_only', False):
        features = True
        feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
        if 'out_indices' in kwargs:
            feature_cfg['out_indices'] = kwargs.pop('out_indices')

    # Build the model
    model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
    model.default_cfg = default_cfg

    # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
    num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
    if pretrained:
        if pretrained_custom_load:
            load_custom_pretrained(model)
        else:
            load_pretrained(
                model,
                num_classes=num_classes_pretrained,
                in_chans=kwargs.get('in_chans', 3),
                filter_fn=pretrained_filter_fn,
                strict=pretrained_strict)
    return model
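
# Example (illustrative sketch; MyViT, my_vit and the cfg values are hypothetical):
# a model factory typically pairs a variant name with its default_cfg and lets this
# helper resolve kwargs and load pretrained weights.
#
#   def my_vit(pretrained=False, **kwargs):
#       cfg = {'url': '', 'num_classes': 1000, 'input_size': (3, 224, 224),
#              'fixed_input_size': True}
#       return build_model_with_cfg(MyViT, 'my_vit', pretrained, cfg, **kwargs)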