Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support model in_chans not equal to pre-trained weights in_chans #2324

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

keves1
Copy link

@keves1 keves1 commented Sep 27, 2024

This PR addresses #2289.

I've made a good start on this and am creating this draft PR to get some feedback before continuing. I have created a load_pretrained function that will copy weights as done by timm, but which supports any number of in_chans of the weights, not just 3. So far I just implemented this new functionality in resnet50 and will expand that to other models if this is a good approach.

With this change, you can create a model like this, in this case the two channels will be copied into 4:

model = resnet50(weights=ResNet50_Weights.SENTINEL1_ALL_MOCO, in_chans=4)

I created a new file torchgeo/models/utils.py where I put the load_pretrained function since I didn't see any suitable existing file. Is there a better place for it?

@github-actions github-actions bot added the models Models and pretrained weights label Sep 27, 2024
@keves1
Copy link
Author

keves1 commented Sep 27, 2024

@microsoft-github-policy-service agree

@adamjstewart
Copy link
Collaborator

I have created a load_pretrained function that will copy weights as done by timm, but which supports any number of in_chans of the weights, not just 3.

Why can't we use timm.models.helpers.load_pretrained instead of writing our own custom code?

@adamjstewart adamjstewart modified the milestones: 0.7.0, 0.6.1 Sep 28, 2024
@keves1
Copy link
Author

keves1 commented Sep 28, 2024

Why can't we use timm.models.helpers.load_pretrained instead of writing our own custom code?

Because that function will only copy the weights to additional input channels if the first convolution layer of the weights has 3 input channels. Otherwise it raises a NotImplementedError exception and randomly initializes the weights. I linked to the timm implementation in my comment on issue #2289 .

@adamjstewart
Copy link
Collaborator

Will try to review next week, this week is quite busy. Apologies for the wait!

@adamjstewart adamjstewart modified the milestones: 0.6.1, 0.6.2 Oct 10, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
models Models and pretrained weights
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants