Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move all core methods to hestcore #6

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions bin/extract_patch_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import sys; sys.path.append('../')
import argparse
import os
import logging
import os

import openslide
from tqdm import tqdm
from core.utils.utils import get_pixel_size

from core.preprocessing.conch_patch_embedder import TileEmbedder
from core.preprocessing.hest_modules.segmentation import TissueSegmenter
from core.preprocessing.hest_modules.wsi import get_pixel_size, OpenSlideWSI
from core.preprocessing.conch_patch_embedder import ConchTileEmbedder
from hestcore.wsi import OpenSlideWSI
from hestcore.segmentation import segment_tissue_deep

# Configure logger
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
Expand Down Expand Up @@ -38,8 +39,7 @@ def process(slide_dir, out_dir, patch_mag, patch_size):
os.makedirs(patch_emb_path, exist_ok=True)

# create tissue segmenter and tile embedder
segmenter = TissueSegmenter(save_path=seg_path, batch_size=64)
embedder = TileEmbedder(target_patch_size=patch_size, target_mag=patch_mag, save_path=out_dir)
embedder = ConchTileEmbedder(target_patch_size=patch_size, target_mag=patch_mag, save_path=out_dir)

for fn in tqdm(fnames):

Expand All @@ -49,13 +49,21 @@ def process(slide_dir, out_dir, patch_mag, patch_size):
fn_no_extension = os.path.splitext(fn)[0]

# 2. segment tissue
gdf_contours = segmenter.segment_tissue(
wsi=wsi,
pixel_size=pixel_size,
save_bn=fn_no_extension,
gdf_contours = segment_tissue_deep(
wsi,
pixel_size,
batch_size=64
)

# 3. extract patches and embeddings
# 3. save segmentation + visualization
os.makedirs(os.path.join(out_dir, 'geojson'), exist_ok=True)
os.makedirs(os.path.join(out_dir, 'jpeg'), exist_ok=True)
seg_name = fn_no_extension + '_tissue_vis.jpeg'
wsi.get_tissue_vis(gdf_contours).save(os.path.join(out_dir, 'jpeg', seg_name))
seg_name = fn_no_extension + '_tissue_mask.geojson'
gdf_contours.to_file(os.path.join(out_dir, 'geojson', seg_name), driver="GeoJSON")

# 4. extract patches and embeddings
embedder.embed_tiles(
wsi=wsi,
gdf_contours=gdf_contours,
Expand Down
90 changes: 30 additions & 60 deletions core/preprocessing/conch_patch_embedder.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from tqdm import tqdm
import numpy as np
import h5py
import os
from PIL import Image

import torch
from torch.utils.data import Dataset

import h5py
import numpy as np
import torch
import torchvision.transforms as transforms
from conch.open_clip_custom import create_model_from_pretrained

from hestcore.datasets import WSIPatcherDataset
# from core.preprocessing.hest_modules.wsi import WSIPatcher
from core.preprocessing.hest_modules.wsi import OpenSlideWSIPatcher, get_pixel_size
from hestcore.wsi import OpenSlideWSIPatcher
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm

from core.utils.utils import get_pixel_size, mag_to_px_size


def save_hdf5(output_fpath,
Expand Down Expand Up @@ -72,7 +74,7 @@ def collate_features(batch):
return features, coords


class TileEmbedder:
class ConchTileEmbedder:
def __init__(self,
model_name='conch_ViT-B-16',
model_repo='hf_hub:MahmoodLab/conch',
Expand Down Expand Up @@ -100,13 +102,23 @@ def embed_tiles(self, wsi, gdf_contours, fn) -> str:
patching_save_path = os.path.join(self.save_path, 'patches', f'{fn}_patches.png')
embedding_save_path = os.path.join(self.save_path, 'patch_embeddings', f'{fn}.h5')

dataset = TileDataset(
wsi=wsi,
gdf_contours=gdf_contours,
target_patch_size=self.target_patch_size,
target_mag=self.target_mag,
eval_transform=self.img_transforms,
save_path=patching_save_path)
dst_pixel_size = mag_to_px_size(self.target_mag)
src_pixel_size = get_pixel_size(wsi.img)

patcher = wsi.create_patcher(
self.target_patch_size,
src_pixel_size,
dst_pixel_size,
mask=gdf_contours,
pil=True
)

conch_transforms = transforms.Compose([
self.img_transforms,
transforms.Lambda(lambda x: torch.unsqueeze(x, 0))
])

dataset = WSIPatcherDataset(patcher, transform=conch_transforms)

dataloader = torch.utils.data.DataLoader(
dataset,
Expand All @@ -130,46 +142,4 @@ def embed_tiles(self, wsi, gdf_contours, fn) -> str:
}
save_hdf5(embedding_save_path, mode=mode, asset_dict=asset_dict)

return embedding_save_path


class TileDataset(Dataset):
    """Dataset that serves the patches of one slide as transformed images.

    An ``OpenSlideWSIPatcher`` is built over ``wsi`` with ``gdf_contours``
    passed as its mask. Each item is the patch returned by the patcher,
    converted to a PIL RGB image, run through ``eval_transform`` and given a
    leading dimension of size 1, together with the patch's ``(x, y)``
    coordinates as reported by the patcher.
    """

    # Magnification -> destination pixel size handed to the patcher
    # (presumably microns per pixel; confirm against the patcher's docs).
    _MAG_TO_PX_SIZE = ((5, 2.0), (10, 1.0), (20, 0.5), (40, 0.25))

    def __init__(self, wsi, gdf_contours, target_patch_size, target_mag, eval_transform, save_path=None):
        """Build the patcher and optionally save its visualization.

        Args:
            wsi: Slide wrapper; ``wsi.img`` is passed to ``get_pixel_size``.
            gdf_contours: Contours forwarded to the patcher as ``mask``.
            target_patch_size: Patch size forwarded to the patcher.
            target_mag: Magnification, one of 5, 10, 20 or 40.
            eval_transform: Callable applied to each PIL patch; its result
                must support ``.unsqueeze(dim=0)``.
            save_path: Where to save the patching visualization. When
                ``None`` (the default) no visualization is written.

        Raises:
            ValueError: If ``target_mag`` is not a supported magnification.
        """
        self.wsi = wsi
        self.gdf_contours = gdf_contours
        self.eval_transform = eval_transform

        self.patcher = OpenSlideWSIPatcher(
            wsi=wsi,
            patch_size=target_patch_size,
            src_pixel_size=get_pixel_size(wsi.img),
            dst_pixel_size=self.mag_to_px_size(target_mag),
            mask=gdf_contours,
            coords_only=False,
        )
        # ``save_path`` is optional; only write the visualization when the
        # caller actually supplied a destination.
        if save_path is not None:
            self.patcher.save_visualization(path=save_path)

    @staticmethod
    def mag_to_px_size(mag):
        """Return the pixel size associated with magnification ``mag``.

        Raises:
            ValueError: If ``mag`` is not one of 5, 10, 20 or 40.
        """
        # Equality comparison (rather than hashing) keeps the original
        # matching semantics for any value that compares equal to a key.
        for known_mag, px_size in TileDataset._MAG_TO_PX_SIZE:
            if mag == known_mag:
                return px_size
        raise ValueError('Magnification should be in [5, 10, 20, 40].')

    def __len__(self):
        """Return the number of patches reported by the patcher."""
        return len(self.patcher)

    def __getitem__(self, idx):
        """Return ``(transformed_patch, (x, y))`` for patch ``idx``."""
        img, x, y = self.patcher[idx]
        img = Image.fromarray(img, 'RGB')
        # Add a leading dimension of size 1 to the transformed patch.
        img = self.eval_transform(img).unsqueeze(dim=0)
        return img, (x, y)
return embedding_save_path
73 changes: 0 additions & 73 deletions core/preprocessing/hest_modules/SegDataset.py

This file was deleted.

Empty file.
Loading