From e5db918cee31b1714f64458510b5fbfea9ee9aad Mon Sep 17 00:00:00 2001
From: Naoto Usuyama
Date: Mon, 18 Nov 2024 15:42:41 -0800
Subject: [PATCH] Update Inference section

---
 README.md | 135 +++++++++++++++++++++++++++++------------------------
 1 file changed, 73 insertions(+), 62 deletions(-)

diff --git a/README.md b/README.md
index 9d8c305..5b74776 100644
--- a/README.md
+++ b/README.md
@@ -46,8 +46,81 @@ Step 2. Download model checkpoint and put the model in the pretrained folder whe
 Expect future updates of the model as we are making it more robust and powerful based on feedback from the community. We recommend using the latest version of the model.
 
+## Running Inference with BiomedParse
+
+We’ve streamlined the process for running inference with BiomedParse. Below are details and resources to help you get started.
+
+### How to Run Inference
+To perform inference with BiomedParse, use the provided example code and resources:
+
+- **Inference Code**: Use the example inference script in `example_prediction.py`.
+- **Sample Images**: Load and test with the provided example images located in the `examples` directory.
+- **Model Configuration**: The model settings are defined in `configs/biomedparse_inference.yaml`.
+
+### Example Notebooks
+
+We’ve included sample notebooks to guide you through running inference with BiomedParse:
+
+- **DICOM Inference Example**: Check out the `inference_examples_DICOM.ipynb` notebook for an example using DICOM images.
+- You can also try a quick online demo: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googlecolab/colabtools/blob/master/notebooks/colab-github-demo.ipynb)
+
+### Model Setup
+```python
+from PIL import Image
+import torch
+from modeling.BaseModel import BaseModel
+from modeling import build_model
+from utilities.distributed import init_distributed
+from utilities.arguments import load_opt_from_config_files
+from utilities.constants import BIOMED_CLASSES
+from inference_utils.inference import interactive_infer_image
+import numpy as np
+
+# Build model config
+opt = load_opt_from_config_files(["configs/biomedparse_inference.yaml"])
+opt = init_distributed(opt)
+
+# Load model from pretrained weights
+pretrained_pth = 'pretrained/biomed_parse.pt'
+
+model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()
+
+# Pre-compute text embeddings for the class prompts (plus "background")
+with torch.no_grad():
+    model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(BIOMED_CLASSES + ["background"], is_eval=True)
+```
+
+### Segmentation On Example Images
+```python
+# RGB image input of shape (H, W, 3). Currently only batch size 1 is supported.
+image = Image.open('examples/Part_1_516_pathology_breast.png', formats=['png'])
+image = image.convert('RGB')
+
+# Text prompts querying objects in the image. Multiple prompts can be provided.
+prompts = ['neoplastic cells', 'inflammatory cells']
+
+# Load the ground truth masks, one per prompt
+gt_masks = []
+for prompt in prompts:
+    gt_mask = Image.open(f"examples/Part_1_516_pathology_breast_{prompt.replace(' ', '+')}.png", formats=['png'])
+    gt_mask = 1 * (np.array(gt_mask.convert('RGB'))[:, :, 0] > 0)
+    gt_masks.append(gt_mask)
+
+# Run inference: one predicted probability map per prompt
+pred_mask = interactive_infer_image(model, image, prompts)
+
+# Compare each prediction against its ground truth mask
+for i, pred in enumerate(pred_mask):
+    gt = gt_masks[i]
+    p = 1 * (pred > 0.5)  # binarize the predicted probability map
+    dice = 2.0 * (p & gt).sum() / (p.sum() + gt.sum())
+    print(f'Dice score for {prompts[i]}: {dice:.4f}')
+```
+
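+To eyeball the predictions, the snippet below overlays each thresholded mask on the input image. This is a minimal sketch, not part of the repository's API: it assumes `image`, `prompts`, and `pred_mask` from the block above are still in scope, and the 0.5 threshold and red overlay color are arbitrary choices.
+```python
+# Hypothetical visualization helper using only NumPy and PIL
+img_arr = np.array(image)
+for prompt, pred in zip(prompts, pred_mask):
+    overlay = img_arr.copy()
+    overlay[pred > 0.5] = [255, 0, 0]  # paint predicted pixels red
+    blended = (0.5 * img_arr + 0.5 * overlay).astype(np.uint8)  # 50% opacity blend
+    Image.fromarray(blended).save(f"pred_{prompt.replace(' ', '_')}.png")
+```
+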
+Detection and recognition inference code is provided in `inference_utils/output_processing.py`; a usage sketch follows the list.
+
+- `check_mask_stats()`: Outputs a p-value for a model-predicted mask, used for detection.
+- `combine_masks()`: Combines predictions for non-overlapping masks.
+
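+The snippet below shows how `check_mask_stats()` might be called on the outputs of the segmentation example above. The argument names and order are assumptions made for illustration, not confirmed API; check `inference_utils/output_processing.py` for the actual signatures.
+```python
+from inference_utils.output_processing import check_mask_stats
+
+# Assumed call pattern (verify against output_processing.py): image array,
+# predicted mask scaled to 0-255, imaging modality, and the text target.
+target = prompts[0]
+p_value = check_mask_stats(np.array(image), pred_mask[0] * 255, 'Pathology', target)
+print(f'p-value for {target}: {p_value:.4f}')
+```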
+
 ## Finetune on Your Own Data
 While BiomedParse can take in arbitrary image and text prompts, it can only reasonably segment the targets that it has learned during pretraining! If you have a specific segmentation task that the latest checkpoint doesn't do well on, here are the instructions on how to finetune it on your own data.
+
 ### Raw Image and Annotation
 BiomedParse expects images and ground truth masks in 1024x1024 PNG format. For each dataset, put the raw image and mask files in the following format
 ```
@@ -112,68 +185,6 @@ bash assets/scripts/eval.sh
 ```
 This will evaluate the model on the test datasets you specified in configs/biomed_seg_lang_v1.yaml. We put BiomedParseData-Demo as the default. You can add any other datasets in the list.
 
-## Run Inference
-Example inference code is provided in `example_prediction.py`. We provided example images in `examples` to load from. Model configuration is provided in `configs/biomedparse_inference.yaml`.
-
-### Example Notebooks
-Check our inference examples for DICOM images at inference_examples_DICOM.ipynb.
-
-### Model Setup
-```sh
-from PIL import Image
-import torch
-from modeling.BaseModel import BaseModel
-from modeling import build_model
-from utilities.distributed import init_distributed
-from utilities.arguments import load_opt_from_config_files
-from utilities.constants import BIOMED_CLASSES
-from inference_utils.inference import interactive_infer_image
-import numpy as np
-
-# Build model config
-opt = load_opt_from_config_files(["configs/biomedparse_inference.yaml"])
-opt = init_distributed(opt)
-
-# Load model from pretrained weights
-pretrained_pth = 'pretrained/biomed_parse.pt'
-
-model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()
-with torch.no_grad():
-    model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(BIOMED_CLASSES + ["background"], is_eval=True)
-```
-
-### Segmentation On Example Images
-```sh
-# RGB image input of shape (H, W, 3). Currently only batch size 1 is supported.
-image = Image.open('examples/Part_1_516_pathology_breast.png', formats=['png'])
-image = image.convert('RGB')
-# text prompts querying objects in the image. Multiple ones can be provided.
-prompts = ['neoplastic cells', 'inflammatory cells']
-
-# load ground truth mask
-gt_masks = []
-for prompt in prompts:
-    gt_mask = Image.open(f"examples/Part_1_516_pathology_breast_{prompt.replace(' ', '+')}.png", formats=['png'])
-    gt_mask = 1*(np.array(gt_mask.convert('RGB'))[:,:,0] > 0)
-    gt_masks.append(gt_mask)
-
-pred_mask = interactive_infer_image(model, image, prompts)
-
-# prediction with ground truth mask
-for i, pred in enumerate(pred_mask):
-    gt = gt_masks[i]
-    dice = (1*(pred>0.5) & gt).sum() * 2.0 / (1*(pred>0.5).sum() + gt.sum())
-    print(f'Dice score for {prompts[i]}: {dice:.4f}')
-```
-
-
-Detection and recognition inference code are provided in `inference_utils/output_processing.py`.
-
-- `check_mask_stats()`: Outputs p-value for model-predicted mask for detection.
-- `combine_masks()`: Combines predictions for non-overlapping masks.
-
-