Skip to content

Commit

Permalink
Overhaul 2.1 Remove Dependencies / Add Full Timm Support (#3)
Browse files Browse the repository at this point in the history
* remove dependencies and add full timm support

* update readme

* add unet head upsampling param

* fix flake8

* update readme

* make output_stride optional, not all timm models support an output stride arg
  • Loading branch information
isaaccorley authored Feb 27, 2024
1 parent f4cb90a commit 3401b84
Show file tree
Hide file tree
Showing 43 changed files with 5,699 additions and 4,076 deletions.
490 changes: 155 additions & 335 deletions README.md

Large diffs are not rendered by default.

Binary file added assets/pretrained_weights.webp
Binary file not shown.
8 changes: 2 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,9 @@ classifiers = [
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [
"einops>=0.7.0",
"timm>=0.9.12",
"torch>=1.13",
"torchvision>=0.14",
"einops>=0.3",
"pretrainedmodels==0.7.4",
"efficientnet-pytorch==0.7.1",
]
dynamic = ["version"]

Expand All @@ -50,7 +47,6 @@ style = [
tests = [
"pytest>=7.3",
"pytest-cov>=4",
"six>=1.16.0"
]
all = [
"torchseg[style,tests]",
Expand All @@ -71,7 +67,7 @@ exclude_lines = [

[tool.isort]
profile = "black"
known_first_party = ["docs", "tests", "torchseg", "train"]
known_first_party = ["tests", "torchseg", "train"]
skip_gitignore = true
color_output = true

Expand Down
7 changes: 2 additions & 5 deletions requirements/required.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,5 @@ setuptools==69.0.0

# install
einops==0.7.0
timm==0.9.2
torch==2.1.2
torchvision==0.16.2
pretrainedmodels==0.7.4
efficientnet-pytorch==0.7.1
timm==0.9.12
torch==2.1.2
4 changes: 1 addition & 3 deletions requirements/tests.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# tests
pytest==7.4.4
pytest-cov==4.1.0
mock==5.1.0
six==1.16.0
pytest-cov==4.1.0
40 changes: 40 additions & 0 deletions scripts/list_compatible_encoders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import json

import timm
from tqdm import tqdm

if __name__ == "__main__":
# Check for models that support `features_only=True``
works, fails = {}, []
for model in tqdm(timm.list_models()):
try:
m = timm.create_model(model, pretrained=False, features_only=True)
works[model] = dict(
indices=m.feature_info.out_indices,
channels=m.feature_info.channels(),
reduction=m.feature_info.reduction(),
module=m.feature_info.module_name(),
)
except RuntimeError:
fails.append(model)

with open("encoders_features_only_supported.json", "w") as f:
json.dump(works, f, indent=2)

# Check for models that support `get_intermediate_layers``
intermediate_layers_support = []
unsupported = []

for model in tqdm(fails):
m = timm.create_model(model, pretrained=False)
if hasattr(m, "get_intermediate_layers"):
intermediate_layers_support.append(model)
else:
unsupported.append(model)

with open("encoders_get_intermediate_layers_supported.json", "w") as f:
json.dump(intermediate_layers_support, f, indent=2)

# Save unsupported timm models
with open("encoders_unsupported.json", "w") as f:
json.dump(unsupported, f, indent=2)
Loading

0 comments on commit 3401b84

Please sign in to comment.