Skip to content

Commit

Permalink
feat: support models from local sources (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
eloy-encord authored Mar 13, 2024
1 parent 76721a0 commit b603f2c
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 1 deletion.
1 change: 1 addition & 0 deletions clip_eval/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

class PROJECT_PATHS:
EMBEDDINGS = CACHE_PATH / "embeddings"
MODELS = CACHE_PATH / "models"
REDUCTIONS = CACHE_PATH / "reductions"


Expand Down
18 changes: 18 additions & 0 deletions clip_eval/models/local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from .CLIP_model import OpenCLIPModel


class LocalCLIPModel(OpenCLIPModel):
def __init__(
self,
title: str,
device: str | None = None,
*,
title_in_source: str,
cache_dir: str | None = None,
**kwargs,
) -> None:
super().__init__(title, device, title_in_source=title_in_source, cache_dir=cache_dir, **kwargs)

def _setup(self, **kwargs) -> None:
self.pretrained = (self._cache_dir / "checkpoint.pt").as_posix()
super()._setup(**kwargs)
6 changes: 5 additions & 1 deletion clip_eval/models/provider.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .CLIP_model import CLIPModel, ClosedCLIPModel, OpenCLIPModel, SiglipModel
from .local import LocalCLIPModel


class ModelProvider:
Expand Down Expand Up @@ -32,7 +33,7 @@ def list_model_names(self) -> list[str]:
title_in_source="wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M",
)
model_provider.register_model("fashion", ClosedCLIPModel, title_in_source="patrickjohncyh/fashion-clip")
model_provider.register_model("rscid", ClosedCLIPModel, title_in_source="flax-community/clip-rsicd")
model_provider.register_model("rsicd", ClosedCLIPModel, title_in_source="flax-community/clip-rsicd")
model_provider.register_model("street", ClosedCLIPModel, title_in_source="geolocal/StreetCLIP")

model_provider.register_model("apple", OpenCLIPModel, title_in_source="hf-hub:apple/DFN5B-CLIP-ViT-H-14")
Expand All @@ -42,3 +43,6 @@ def list_model_names(self) -> list[str]:

model_provider.register_model("siglip_small", SiglipModel, title_in_source="google/siglip-base-patch16-224")
model_provider.register_model("siglip_large", SiglipModel, title_in_source="google/siglip-large-patch16-256")

# Local sources
model_provider.register_model("rsicd-encord", LocalCLIPModel, title_in_source="ViT-B/32")

0 comments on commit b603f2c

Please sign in to comment.