
Commit

update examples
Theodore Zhao committed Nov 18, 2024
1 parent a3ff9e1 commit 0621999
Showing 6 changed files with 46 additions and 15 deletions.
33 changes: 27 additions & 6 deletions README.md
@@ -36,7 +36,14 @@ pip install -r assets/requirements/requirements.txt
BiomedParseData was created by preprocessing publicly available biomedical image segmentation datasets. Check out a subset of our processed datasets on HuggingFace: https://huggingface.co/datasets/microsoft/BiomedParseData. For the source datasets, please check the details here: [BiomedParseData](assets/readmes/DATASET.md). As a quick start, we've sampled a tiny demo dataset at biomedparse_datasets/BiomedParseData-Demo.
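If convenient, the processed data can also be pulled programmatically with `huggingface_hub`; a minimal sketch (the `local_dir` path below is just an assumed example):
```python
# Sketch: download the processed BiomedParseData from the HuggingFace Hub.
# The local_dir is an assumed example path, not prescribed by the repo.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="microsoft/BiomedParseData",
    repo_type="dataset",
    local_dir="biomedparse_datasets/BiomedParseData",
)
```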

## Model Checkpoints
We host our model checkpoints on HuggingFace here: https://huggingface.co/microsoft/BiomedParse.
Step 1. Create the pretrained model folder
```
mkdir pretrained
```
Step 2. Download the model checkpoint and put it in the pretrained folder before running the code. Rename the file to biomed_parse.pt. (A download sketch is given below.)

Expect future updates of the model as we make it more robust and powerful based on feedback from the community. We recommend using the latest version of the model.
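A minimal Python sketch for Step 2 using `huggingface_hub`; the checkpoint filename on the Hub is an assumption here, so substitute the actual filename listed on the model page:
```python
# Sketch for Step 2: download the checkpoint and place it in pretrained/.
# The filename "biomedparse_v1.pt" is an assumed example -- check the model
# page on HuggingFace for the actual checkpoint filename.
import shutil
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="microsoft/BiomedParse",
    filename="biomedparse_v1.pt",  # assumed filename
)
shutil.copy(ckpt_path, "pretrained/biomed_parse.pt")
```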

## Finetune on Your Own Data
While BiomedParse can take in arbitrary image and text prompts, it can only reasonably segment the targets that it has learned during pretraining! If you have a specific segmentation task that the latest checkpoint doesn't handle well, here are the instructions on how to finetune it on your own data.
@@ -118,6 +125,7 @@ from utils.distributed import init_distributed
from utils.arguments import load_opt_from_config_files
from utils.constants import BIOMED_CLASSES
from inference_utils.inference import interactive_infer_image
import numpy as np

# Build model config
def parse_option():
Expand All @@ -142,19 +150,32 @@ with torch.no_grad():
### Segmentation On Example Images
```python
# RGB image input of shape (H, W, 3). Currently only batch size 1 is supported.
image = Image.open('examples/Part_1_516_pathology_breast.png', formats=['png'])
image = image.convert('RGB')
# text prompts querying objects in the image. Multiple prompts can be provided.
prompts = ['neoplastic cells', 'inflammatory cells']

# load the ground truth mask for each prompt and binarize it
gt_masks = []
for prompt in prompts:
    gt_mask = Image.open(f"examples/Part_1_516_pathology_breast_{prompt.replace(' ', '+')}.png", formats=['png'])
    gt_mask = 1*(np.array(gt_mask.convert('RGB'))[:,:,0] > 0)
    gt_masks.append(gt_mask)

# run inference: returns one predicted probability mask per prompt
pred_mask = interactive_infer_image(model, image, prompts)

# compare each prediction with its ground truth mask using the Dice score
for i, pred in enumerate(pred_mask):
    gt = gt_masks[i]
    dice = (1*(pred>0.5) & gt).sum() * 2.0 / (1*(pred>0.5).sum() + gt.sum())
    print(f'Dice score for {prompts[i]}: {dice:.4f}')
```
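For reference, the Dice score printed above follows the standard definition for the thresholded prediction $P$ and ground truth $G$:

$$\mathrm{Dice}(P, G) = \frac{2\,|P \cap G|}{|P| + |G|}$$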
<!--
Detection and recognition inference code are provided in `inference_utils/output_processing.py`.
- `check_mask_stats()`: Outputs p-value for model-predicted mask for detection.
- `combine_masks()`: Combines predictions for non-overlapping masks. -->
- `combine_masks()`: Combines predictions for non-overlapping masks.
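For readers curious what combining non-overlapping masks involves, here is a minimal numpy sketch of the idea only; it is not the `combine_masks()` implementation from `inference_utils/output_processing.py`, and the function name and arguments are made up for illustration:
```python
import numpy as np

def combine_nonoverlapping_masks(prob_masks, threshold=0.5):
    """Illustrative sketch: assign each pixel to the highest-probability prompt.

    prob_masks: array of shape (num_prompts, H, W) with per-prompt probabilities.
    Returns an integer label map: 0 = background, i + 1 = prompt i.
    """
    prob_masks = np.asarray(prob_masks)
    best = prob_masks.argmax(axis=0)      # best prompt index per pixel
    best_prob = prob_masks.max(axis=0)    # its probability
    return np.where(best_prob > threshold, best + 1, 0)

# e.g. label_map = combine_nonoverlapping_masks(pred_mask)
```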
25 changes: 18 additions & 7 deletions example_prediction.py
@@ -6,6 +6,7 @@
from utilities.distributed import init_distributed
from utilities.arguments import load_opt_from_config_files
from utilities.constants import BIOMED_CLASSES
import numpy as np

from inference_utils.inference import interactive_infer_image

Expand All @@ -32,12 +33,22 @@ def parse_option():

# Load image and run inference
# RGB image input of shape (H, W, 3). Currently only batch size 1 is supported.
image = Image.open('examples/Part_1_516_pathology_breast.png', formats=['png'])
image = image.convert('RGB')
# text prompts querying objects in the image. Multiple prompts can be provided.
prompts = ['neoplastic cells', 'inflammatory cells']

# load the ground truth mask for each prompt and binarize it
gt_masks = []
for prompt in prompts:
    gt_mask = Image.open(f"examples/Part_1_516_pathology_breast_{prompt.replace(' ', '+')}.png", formats=['png'])
    gt_mask = 1*(np.array(gt_mask.convert('RGB'))[:,:,0] > 0)
    gt_masks.append(gt_mask)

# run inference: returns one predicted probability mask per prompt
pred_mask = interactive_infer_image(model, image, prompts)

# compare each prediction with its ground truth mask using the Dice score
for i, pred in enumerate(pred_mask):
    gt = gt_masks[i]
    dice = (1*(pred>0.5) & gt).sum() * 2.0 / (1*(pred>0.5).sum() + gt.sum())
    print(f'Dice score for {prompts[i]}: {dice:.4f}')
Binary file added examples/Part_1_516_pathology_breast.png
3 changes: 1 addition & 2 deletions inference_utils/inference.py
@@ -74,9 +74,8 @@ def interactive_infer_image(model, image, prompts):
    # interpolate mask to original image size
    pred_mask_prob = F.interpolate(pred_masks_pos[None,], image_size[-2:], mode='bilinear')[0,:,:data['height'],:data['width']].sigmoid().cpu().numpy()
    pred_masks_pos = (1*(pred_mask_prob > 0.5)).astype(np.uint8)

    return pred_mask_prob



