
[FEATURE] convert model for 1D inputs #1529

Open
pfeatherstone opened this issue Nov 2, 2022 · 7 comments
Labels
enhancement New feature or request

Comments

@pfeatherstone

Is there a way to convert timm models for 1D inputs?
I realize that a 1D tensor with shape [B,C,S] can be reshaped to [B,C,1,S] or [B,C,S,1], but then the filters are unnecessarily large.
Maybe for transformer models this is easier since patch embeddings are linearized, but it would be cool if CNNs could also be converted to 1D.
Cheers
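The reshape workaround mentioned above can be sketched as follows. This is a minimal illustration (shapes and layer sizes are made up for the example), showing a 1D signal viewed as a height-1 "image" so a 2D conv can consume it:

```python
import torch
from torch import nn

# A 1D signal [B, C, S] viewed as a 2D image of height 1: [B, C, 1, S].
x = torch.randn(4, 3, 128)          # batch of 1D signals
x2d = x.unsqueeze(2)                # [4, 3, 1, 128]

# A 2D conv applied to it: the kernel's height dimension is wasted,
# so only a (1, k) kernel makes sense here; anything taller just
# convolves over padding.
conv = nn.Conv2d(3, 8, kernel_size=(1, 3), padding=(0, 1))
y = conv(x2d)                       # [4, 8, 1, 128]
y1d = y.squeeze(2)                  # back to [4, 8, 128]
```

This works, but as noted above the filter parameters (and any square-kernel layers in a pretrained model) are larger than a genuinely 1D model would need.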

@pfeatherstone pfeatherstone added the enhancement New feature or request label Nov 2, 2022
@rwightman
Collaborator

@pfeatherstone actually had a use case for doing this for 3D come up recently; I think it might be a needed extension, and it would make sense to cover 1D as well if I did that. Curious what the use case is/was?

@pfeatherstone
Author

I regularly build classifier models and feature extractors for 1D time series

@pfeatherstone
Author

pfeatherstone commented Apr 18, 2023

So far I have this conversion function:

import timm
import torch
from torch import nn

def _SqueezeExciteForwardMethod(self, x):
    # 1D replacement for SqueezeExcite.forward: pool over the single
    # spatial dim (dim 2) instead of (2, 3).
    x_se = x.mean((2,), keepdim=True)
    x_se = self.conv_reduce(x_se)
    x_se = self.act1(x_se)
    x_se = self.conv_expand(x_se)
    return x * self.gate(x_se)

class GlobalAvgPool1d(nn.Module):
    def forward(self, x):
        return x.flatten(2).mean(-1)

def convert2DModelTo1D(module):
    # Monkeypatch SqueezeExcite globally (affects every instance of the class)
    timm.models.efficientnet_blocks.SqueezeExcite.forward = _SqueezeExciteForwardMethod

    module_output = module

    if hasattr(module, "global_pool"):
        module.global_pool = GlobalAvgPool1d()

    if isinstance(module, nn.BatchNorm2d):
        module_output = nn.BatchNorm1d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats)
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias   = module.bias
        module_output.running_mean        = module.running_mean
        module_output.running_var         = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig

    elif isinstance(module, nn.Conv2d):
        module_output = nn.Conv1d(
            module.in_channels,
            module.out_channels,
            module.kernel_size[0],
            module.stride[0],
            module.padding[0],
            module.dilation[0],
            module.groups,
            module.bias is not None,   # Conv2d always has a .bias attribute; test for None
            module.padding_mode)
        with torch.no_grad():
            # Collapse the 2D kernel [out, in, kh, kw] to 1D by averaging
            # over the last kernel dimension.
            module_output.weight.copy_(module.weight.mean(-1))
            if module.bias is not None:
                module_output.bias = module.bias

    for name, child in module.named_children():
        module_output.add_module(name, convert2DModelTo1D(child))
    del module
    return module_output

This works for efficientnet_b0 so far.
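The core weight-transfer step in the function above can be checked in isolation without timm; this is a minimal sketch with arbitrary layer sizes, collapsing a 2D kernel to 1D by averaging over the last kernel dimension as the snippet does:

```python
import torch
from torch import nn

# Collapse a Conv2d kernel [out, in, kh, kw] to a Conv1d kernel
# [out, in, k] by averaging over the last kernel dimension.
conv2d = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
conv1d = nn.Conv1d(3, 8, kernel_size=3, padding=1, bias=True)
with torch.no_grad():
    conv1d.weight.copy_(conv2d.weight.mean(-1))  # [8, 3, 3, 3] -> [8, 3, 3]
    conv1d.bias.copy_(conv2d.bias)

x = torch.randn(2, 3, 16)   # [B, C, S] 1D input
y = conv1d(x)               # [2, 8, 16]
```

Note this only initializes the 1D kernels from the 2D weights; whether the pretrained features actually transfer to 1D data is an empirical question (and the later comments in this thread suggest they did not transfer well here).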

@pfeatherstone
Author

This works with a lot of CNN models. Maybe the better approach would be to have a flag in create_model() that dictates whether to build a 1D, 2D or 3D model, but for now this is sufficient.

@rwightman
Collaborator

@pfeatherstone thanks for sharing that snippet. I think the initial approach would be to try a helper that mutates an existing model like the above; adding arguments through the create_model fn is a bigger commitment that adds long-term maintenance burden if it doesn't end up being used much...

@pfeatherstone
Author

I have to say, the results were really poor. The better option was to convert my 1D data into 2D using torch.stft() and then use a normal 2D model...
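The STFT route mentioned here can be sketched as follows; the signal length, FFT size, and hop are illustrative choices, not values from the thread:

```python
import torch

# Turn a batch of 1D signals into 2D spectrograms with torch.stft,
# then feed them to an ordinary 2D model.
x = torch.randn(4, 16000)                      # [B, S] raw signals
spec = torch.stft(x, n_fft=512, hop_length=256,
                  window=torch.hann_window(512),
                  return_complex=True)          # [B, 257, frames], complex
mag = spec.abs().log1p()                        # log-magnitude, real-valued
img = mag.unsqueeze(1)                          # [B, 1, F, T] image-like input
# img can now go into a standard 2D model, e.g. one built with
# timm.create_model(..., in_chans=1)
```

With center=True (the default), the number of frames is 1 + S // hop_length, so here the spectrogram is [4, 257, 63].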

@pfeatherstone
Author

Tried this again optimistically thinking I would get better results. Nope...
