
Replace encoder_weights: str | None = "imagenet" with pretrained: bool = True #49

Open
DimitrisMantas opened this issue Jul 6, 2024 · 2 comments

Comments

DimitrisMantas commented Jul 6, 2024

The encoder_weights parameter in the model initializers is a bit ambiguous/strange, especially given its lax type hint.

At first glance, it looks like you can pass strings representing other weight sets (e.g., “coco”, “instagram”, “ssl”), but it turns out that any non-null argument loads the timm weights for the particular encoder. This is because timm expects a boolean flag here, and so torchseg.encoders.get_encoder() only checks whether the argument is None:

# Timm Encoders
else:
    if name.split(".")[0] in TIMM_ENCODERS:
        encoder = TimmEncoder(
            name=name,
            in_channels=in_channels,
            depth=depth,
            indices=indices,
            output_stride=output_stride,
            pretrained=weights is not None,
            **kwargs,
        )
    elif name.split(".")[0] in TIMM_VIT_ENCODERS:
        encoder = TimmViTEncoder(
            name=name,
            in_channels=in_channels,
            depth=depth,
            indices=indices,
            pretrained=weights is not None,
            scale_factors=scale_factors,
            **kwargs,
        )

I understand that this had some value in smp, because it offered multiple weight sets for some models, but timm exposes only one per model name, so maybe it would make sense to change our API to reflect this.
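A minimal sketch of what the proposed API could look like (the simplified signature and the direct timm call are hypothetical, not torchseg's actual code):

import timm

def get_encoder(name: str, in_channels: int = 3, pretrained: bool = True, **kwargs):
    # The boolean forwards straight to timm's own `pretrained` argument,
    # with no `weights is not None` translation in between.
    return timm.create_model(
        name,
        pretrained=pretrained,
        in_chans=in_channels,
        features_only=True,
        **kwargs,
    )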

DimitrisMantas commented Jul 6, 2024

I just saw that, right above the block I referenced, the value of encoder_weights is actually used to look up MixTransformer weights. I understand these are pulled from smp, but they are all ImageNet weights...

isaaccorley (Owner) commented

This was left in for backwards compatibility. I agree with this, though, particularly because with timm you load the weights through the encoder/model name, like vit_b16_224.mae. We can likely add a new pretrained arg and then deprecate the weights arg.
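One possible migration path (a sketch only; the sentinel-based deprecation pattern below is an assumption, not committed code):

import warnings

def get_encoder(name, *, pretrained=True, weights="deprecated", **kwargs):
    # Accept both arguments for a release cycle; warn when the old one is
    # passed explicitly, then translate it into the new boolean flag.
    if weights != "deprecated":
        warnings.warn(
            "`weights` is deprecated; use `pretrained: bool` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        pretrained = weights is not None
    ...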
