diff --git a/clip_eval/constants.py b/clip_eval/constants.py index b0828ce..183c2f8 100644 --- a/clip_eval/constants.py +++ b/clip_eval/constants.py @@ -12,6 +12,7 @@ class PROJECT_PATHS: EMBEDDINGS = CACHE_PATH / "embeddings" + MODELS = CACHE_PATH / "models" REDUCTIONS = CACHE_PATH / "reductions" diff --git a/clip_eval/models/local.py b/clip_eval/models/local.py new file mode 100644 index 0000000..6103b06 --- /dev/null +++ b/clip_eval/models/local.py @@ -0,0 +1,18 @@ +from .CLIP_model import OpenCLIPModel + + +class LocalCLIPModel(OpenCLIPModel): + def __init__( + self, + title: str, + device: str | None = None, + *, + title_in_source: str, + cache_dir: str | None = None, + **kwargs, + ) -> None: + super().__init__(title, device, title_in_source=title_in_source, cache_dir=cache_dir, **kwargs) + + def _setup(self, **kwargs) -> None: + self.pretrained = (self._cache_dir / "checkpoint.pt").as_posix() + super()._setup(**kwargs) diff --git a/clip_eval/models/provider.py b/clip_eval/models/provider.py index b4e8a24..bfd5a87 100644 --- a/clip_eval/models/provider.py +++ b/clip_eval/models/provider.py @@ -1,4 +1,5 @@ from .CLIP_model import CLIPModel, ClosedCLIPModel, OpenCLIPModel, SiglipModel +from .local import LocalCLIPModel class ModelProvider: @@ -32,7 +33,7 @@ def list_model_names(self) -> list[str]: title_in_source="wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M", ) model_provider.register_model("fashion", ClosedCLIPModel, title_in_source="patrickjohncyh/fashion-clip") -model_provider.register_model("rscid", ClosedCLIPModel, title_in_source="flax-community/clip-rsicd") +model_provider.register_model("rsicd", ClosedCLIPModel, title_in_source="flax-community/clip-rsicd") model_provider.register_model("street", ClosedCLIPModel, title_in_source="geolocal/StreetCLIP") model_provider.register_model("apple", OpenCLIPModel, title_in_source="hf-hub:apple/DFN5B-CLIP-ViT-H-14") @@ -42,3 +43,6 @@ def list_model_names(self) -> list[str]: model_provider.register_model("siglip_small", SiglipModel, title_in_source="google/siglip-base-patch16-224") model_provider.register_model("siglip_large", SiglipModel, title_in_source="google/siglip-large-patch16-256") + +# Local sources +model_provider.register_model("rsicd-encord", LocalCLIPModel, title_in_source="ViT-B/32")