Add General Pathology Foundation Model for embedding tool #61
base: main
Conversation
Pull Request Overview
Adds support for a Generalizable Pathology Foundation Model (GPFM) to the embedding extractor, including a UI option, help text, and a basic test. Key changes:
- Introduces a simplified DinoVisionTransformer-based GPFM implementation with on-demand weight download and preprocessing.
- Wires GPFM into model selection and transforms (a rough sketch of this kind of registry change follows the list); updates tests and help.
- Updates Docker image to CUDA base and installs additional dependencies.
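As a rough illustration only (not the PR's actual code), wiring a new model into the extractor's selection logic usually comes down to adding an entry to a name-to-constructor mapping; every name below is hypothetical:

```python
from torchvision import models

class GPFMModel:  # stand-in for the GPFM class this PR adds to pytorch_embedding.py
    pass

# Hypothetical registry; the real file may organize model selection differently.
MODEL_BUILDERS = {
    "resnet50": lambda: models.resnet50(weights="DEFAULT"),
    "gpfm": lambda: GPFMModel(),  # weights would be downloaded on first use
}

def build_model(name):
    try:
        return MODEL_BUILDERS[name]()
    except KeyError:
        raise ValueError(f"Unsupported model: {name}")
```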
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| tools/galaxy-embedding_extractor/pytorch_embedding.xml | Adds GPFM option, a new test, and updated help text describing the model. |
| tools/galaxy-embedding_extractor/pytorch_embedding.py | Implements GPFM model, weight download, preprocessing, model registry changes, and dataloader tweak. |
| tools/galaxy-embedding_extractor/Docker/Dockerfile | Switches to CUDA base and adds several Python/system dependencies for model support. |
In tools/galaxy-embedding_extractor/pytorch_embedding.py:

```python
        return model

    # Standard torchvision models
    if "weights" in inspect.signature(
```
Copilot AI (Oct 3, 2025):
inspect is not imported; this will raise NameError. Either import inspect at the top or use the already-imported signature function as elsewhere in this file. Suggested fix: replace inspect.signature(...) with signature(...).
| if "weights" in inspect.signature( | |
| if "weights" in signature( |
Also in tools/galaxy-embedding_extractor/pytorch_embedding.py:

```python
            model.get_transformer()
        ])
    else:
        transform = model.get_transformer()
```
Copilot AI (Oct 3, 2025):
The apply_normalization flag is ignored for the GPFM path; model.get_transformer() always includes Normalize, so setting apply_normalization=false has no effect. Recommend updating GPFMModel.get_transformer to accept apply_normalization and omitting the final Normalize when false, then pass that flag here, e.g., model.get_transformer(apply_normalization=apply_normalization).
Suggested change:

```diff
-            model.get_transformer()
+            model.get_transformer(apply_normalization=apply_normalization)
         ])
     else:
-        transform = model.get_transformer()
+        transform = model.get_transformer(apply_normalization=apply_normalization)
```
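A possible shape for the corresponding GPFMModel.get_transformer change, shown only as a hedged sketch; the resize/crop sizes and normalization statistics are placeholders, not the values the PR actually uses:

```python
from torchvision import transforms

class GPFMModel:
    # ... remainder of the class as implemented in the PR ...

    def get_transformer(self, apply_normalization=True):
        """Build the preprocessing pipeline, skipping Normalize when requested."""
        steps = [
            transforms.Resize(224),      # placeholder size
            transforms.CenterCrop(224),  # placeholder size
            transforms.ToTensor(),
        ]
        if apply_normalization:
            # ImageNet statistics shown as placeholders; substitute the PR's real values.
            steps.append(transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ))
        return transforms.Compose(steps)
```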
Still in tools/galaxy-embedding_extractor/pytorch_embedding.py:

```python
    def _download_weights(self, url, filepath):
        """Download GPFM weights from the official repository."""
        if os.path.exists(filepath):
            logging.info(f"GPFM weights already exist at {filepath}")
            return True

        logging.info(f"Downloading GPFM weights from {url}")
        try:
            response = requests.get(url, stream=True, timeout=300)
            response.raise_for_status()

            os.makedirs(os.path.dirname(filepath), exist_ok=True)

            # Get file size for progress tracking
            total_size = int(response.headers.get('content-length', 0))
            downloaded = 0

            with open(filepath, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        downloaded += len(chunk)
                        if total_size > 0:
                            progress = (downloaded / total_size) * 100
                            if downloaded % (1024 * 1024 * 10) == 0:  # Log every 10MB
                                logging.info(f"Downloaded {downloaded // (1024 * 1024)}MB / {total_size // (1024 * 1024)}MB ({progress:.1f}%)")

            logging.info(f"GPFM weights downloaded successfully to {filepath}")
            return True

        except Exception as e:
            logging.error(f"Failed to download GPFM weights: {e}")
            if os.path.exists(filepath):
                os.remove(filepath)  # Clean up partial download
            return False
```
Copilot AI (Oct 3, 2025):
Weights are downloaded without integrity verification; a compromised or truncated file could be loaded. Add a known SHA256 (or similar) and verify the checksum after download (and before load), failing fast if it does not match.
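One way to address this, as a sketch under assumptions (GPFM_WEIGHTS_SHA256 is a placeholder; the real digest would have to be pinned from a trusted copy of the released checkpoint):

```python
import hashlib
import logging
import os

# Placeholder: pin the known-good SHA256 of the released GPFM checkpoint here.
GPFM_WEIGHTS_SHA256 = "<expected-sha256-hex-digest>"

def _verify_checksum(filepath, expected_sha256, chunk_size=8192):
    """Return True only if the file's SHA256 matches the pinned digest."""
    sha256 = hashlib.sha256()
    with open(filepath, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            sha256.update(chunk)
    digest = sha256.hexdigest()
    if digest != expected_sha256:
        logging.error(f"Checksum mismatch for {filepath}: got {digest}")
        os.remove(filepath)  # discard the suspect file so it is never loaded
        return False
    return True
```

_download_weights could call this immediately after the download loop and return False on a mismatch, so a truncated or tampered checkpoint is never handed to the model loader.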
In tools/galaxy-embedding_extractor/Docker/Dockerfile:

```dockerfile
RUN pip install --no-cache-dir numpy==1.24.4

# Install timm for GigaPath tile encoder (critical for compatibility)
RUN pip install --no-cache-dir timm>=1.0.3
```
Copilot AI (Oct 3, 2025):
The '>' operator will be interpreted by the shell for redirection, so the version constraint may be ignored and output redirected to a file. Quote the specifier or pin via a requirements file, e.g., RUN pip install --no-cache-dir 'timm>=1.0.3'.
Suggested change:

```diff
-RUN pip install --no-cache-dir timm>=1.0.3
+RUN pip install --no-cache-dir 'timm>=1.0.3'
```
Also in tools/galaxy-embedding_extractor/Docker/Dockerfile:

```dockerfile
RUN pip install --no-cache-dir git+https://github.com/prov-gigapath/prov-gigapath.git

# Install remaining Python dependencies
RUN pip install --no-cache-dir Pillow opencv-python pandas fastparquet argparse logging multiprocessing
```
Copilot AI (Oct 3, 2025):
argparse, logging, and multiprocessing are part of Python's standard library; installing similarly named PyPI packages can shadow/break stdlib behavior. Also, requests is required by the GPFM code but is not installed. Replace with: RUN pip install --no-cache-dir Pillow opencv-python pandas fastparquet requests.
Suggested change:

```diff
-RUN pip install --no-cache-dir Pillow opencv-python pandas fastparquet argparse logging multiprocessing
+RUN pip install --no-cache-dir Pillow opencv-python pandas fastparquet requests
```