Official implementation of ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Biomedical Image, accepted at ECCV 2024.
Hallee E. Wong, Marianne Rakic, John Guttag, Adrian V. Dalca
- (2024-12-31) Released example training code
- (2024-12-12) Released full prompt simulation code
- (2024-07-01) ScribblePrompt has been accepted to ECCV 2024!
- (2024-06-17) ScribblePrompt won the Bench-to-Bedside Award at the CVPR 2024 DCAMI Workshop!
- (2024-04-16) Released MedScribble, a diverse dataset of segmentation tasks with scribble annotations
- (2024-04-15) An updated version of the paper is on arXiv!
- (2024-04-14) Added a Google Colab tutorial
- (2024-01-19) Released scribble simulation code
- (2023-12-15) Released model code and weights
- (2023-12-12) Paper and online demo released
ScribblePrompt is an interactive segmentation tool that enables users to segment unseen structures in medical images using scribbles, clicks, and bounding boxes.
- Interactive online demo on Hugging Face Spaces
- See Installation and Getting Started for how to run the Gradio demo locally
- Google Colab tutorial (Jupyter notebook) using the pre-trained models
- Jupyter notebook tutorials on training and the prompt generator code
We provide checkpoints for two versions of ScribblePrompt:
- ScribblePrompt-UNet, with an efficient fully-convolutional architecture
- ScribblePrompt-SAM, based on the Segment Anything Model
Both models have been trained with iterative scribble, click, and bounding box interactions on a diverse collection of 65 medical imaging datasets with both real and synthetic labels.
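Training with iterative interactions means each new prompt is simulated from the errors of the model's previous prediction. The sketch below illustrates one simple version of that idea for clicks: it places a positive click on a missed foreground pixel or a negative click on a spurious one, whichever error type covers more pixels. It is a hypothetical helper written for illustration, not ScribblePrompt's prompt simulation code. Masks are assumed to be nested lists indexed as mask[y][x], and clicks are returned as (x, y) with label 1 for positive and 0 for negative, following the Segment Anything convention (assumed, not confirmed, to match ScribblePrompt).

```python
import random
from typing import List, Optional, Tuple

Mask = List[List[int]]  # binary mask indexed as mask[y][x]


def simulate_corrective_click(
    pred: Mask, target: Mask, rng: random.Random
) -> Optional[Tuple[Tuple[int, int], int]]:
    """Return ((x, y), label) for one corrective click, or None if pred matches target.

    Hypothetical helper. label 1 = positive click (missed foreground),
    label 0 = negative click (spurious foreground).
    """
    false_neg = []  # pixels in target but missing from pred
    false_pos = []  # pixels in pred but not in target
    for y, (pred_row, target_row) in enumerate(zip(pred, target)):
        for x, (p, t) in enumerate(zip(pred_row, target_row)):
            if t and not p:
                false_neg.append((x, y))
            elif p and not t:
                false_pos.append((x, y))
    if not false_neg and not false_pos:
        return None
    # Correct whichever error type currently covers more pixels.
    if len(false_neg) >= len(false_pos):
        return rng.choice(false_neg), 1
    return rng.choice(false_pos), 0
```

A returned click can be appended to point_coords and point_labels before the next call to predict. Sampling uniformly over error pixels is the simplest choice; simulators often prefer points far from the error boundary, which this sketch does not attempt.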
We release MedScribble, a dataset of multi-annotator scribble annotations for diverse biomedical image segmentation tasks, under ./MedScribble. See the README for more info and ./MedScribble/tutorial.ipynb for a preview of the data.
You can install scribbleprompt in two ways:
- With pip:
pip install git+https://github.com/halleewong/ScribblePrompt.git
- Manually: clone the repository and install the dependencies
git clone https://github.com/halleewong/ScribblePrompt
python -m pip install -r ./ScribblePrompt/requirements.txt
export PYTHONPATH="$PYTHONPATH:$(realpath ./ScribblePrompt)"
The following optional dependency is required for the local demo:
pip install gradio==3.40.1
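Before launching the demo, it can help to confirm that both the package and the optional demo dependency are importable in the active environment. This minimal sketch uses only the standard library; the helper name missing_modules is hypothetical.

```python
import importlib.util
from typing import Iterable, List


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found on the current path.

    Hypothetical helper: find_spec locates a module without importing it.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]
```

For example, missing_modules(["scribbleprompt", "gradio"]) returns an empty list when both are available. Because find_spec only locates a module, this check will not catch a package that is installed but fails at import time.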
First, download the model checkpoints to ./checkpoints.
To run an interactive demo locally:
python demos/app.py
To instantiate ScribblePrompt-UNet and make a prediction:
from scribbleprompt import ScribblePromptUNet
sp_unet = ScribblePromptUNet()
mask = sp_unet.predict(
    image,         # (B, 1, H, W)
    point_coords,  # (B, n, 2)
    point_labels,  # (B, n)
    scribbles,     # (B, 2, H, W)
    box,           # (B, n, 4)
    mask_input,    # (B, 1, H, W)
)  # -> (B, 1, H, W)
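The predict call above takes prompts with the documented shapes. As an illustration, the hypothetical helper below builds batch-of-one point_coords, point_labels, and scribbles as nested Python lists from pixel coordinates. The channel order of scribbles (channel 0 for positive strokes, channel 1 for negative strokes), the (x, y) order of point coordinates, and the 1/0 label convention are assumptions borrowed from Segment Anything; check them against the repository before relying on them.

```python
from typing import List, Sequence, Tuple

Pixel = Tuple[int, int]  # (x, y) = (column, row)


def build_prompts(
    height: int,
    width: int,
    clicks: Sequence[Tuple[Pixel, int]],
    pos_scribble: Sequence[Pixel] = (),
    neg_scribble: Sequence[Pixel] = (),
):
    """Build batch-of-one prompt arrays as nested lists (hypothetical helper).

    Returns point_coords (1, n, 2), point_labels (1, n) and scribbles (1, 2, H, W).
    Assumes scribbles channel 0 holds positive strokes and channel 1 negative ones.
    """
    point_coords: List = [[[float(x), float(y)] for (x, y), _ in clicks]]
    point_labels: List = [[float(label) for _, label in clicks]]
    # Independent zero-filled H x W grids for the two scribble channels.
    scribbles: List = [[[[0.0] * width for _ in range(height)] for _ in range(2)]]
    for channel, pixels in enumerate((pos_scribble, neg_scribble)):
        for x, y in pixels:
            scribbles[0][channel][y][x] = 1.0
    return point_coords, point_labels, scribbles
```

Assuming predict expects PyTorch tensors, each result can be converted with, for example, torch.tensor(point_coords, dtype=torch.float32) before the call.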
To instantiate ScribblePrompt-SAM and make a prediction:
from scribbleprompt import ScribblePromptSAM
sp_sam = ScribblePromptSAM()
mask, img_features, low_res_logits = sp_sam.predict(
    image,         # (B, 1, H, W)
    point_coords,  # (B, n, 2)
    point_labels,  # (B, n)
    scribbles,     # (B, 2, H, W)
    box,           # (B, n, 4)
    mask_input,    # (B, 1, 256, 256)
)  # -> (B, 1, H, W), (B, 16, 256, 256), (B, 1, 256, 256)
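When a ground-truth label is available (for example, to simulate a bounding-box prompt during evaluation), a box can be derived from the mask. This sketch returns a tight box in [x_min, y_min, x_max, y_max] pixel coordinates; that corner ordering is the Segment Anything convention and is assumed, not confirmed, to match what ScribblePrompt expects for box. The helper name mask_to_box is hypothetical.

```python
from typing import List, Optional, Sequence


def mask_to_box(mask: Sequence[Sequence[int]]) -> Optional[List[float]]:
    """Tight [x_min, y_min, x_max, y_max] box around the nonzero pixels of mask[y][x].

    Hypothetical helper. Returns None for an empty mask.
    """
    ys = [y for y, row in enumerate(mask) if any(row)]
    xs = [x for row in mask for x, v in enumerate(row) if v]
    if not ys:
        return None
    return [float(min(xs)), float(min(ys)), float(max(xs)), float(max(ys))]
```

Wrapping the result as [[box]] gives the (B, n, 4) layout with B = 1 and n = 1.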
For best results, image should have spatial dimensions
For ScribblePrompt-UNet, mask_input should be the logits from the previous prediction. For ScribblePrompt-SAM, mask_input should be low_res_logits from the previous prediction.
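For ScribblePrompt-SAM, this feedback amounts to a loop in which each step's low_res_logits becomes the next step's mask_input. The hypothetical helper below sketches that loop for any callable with the same positional arguments and three return values as sp_sam.predict. Whether predict accepts None as mask_input on the first step is an assumption; initial_mask_input lets you pass an appropriately shaped (B, 1, 256, 256) tensor instead if it does not.

```python
from typing import Any, Callable, Dict, Optional, Sequence


def refine_with_sam(
    predict: Callable[..., Any],
    image: Any,
    steps: Sequence[Dict[str, Any]],
    initial_mask_input: Optional[Any] = None,
) -> Any:
    """Run several interaction steps, feeding low_res_logits back as mask_input.

    Hypothetical helper. Each element of `steps` holds that step's prompts under
    the keys point_coords, point_labels, scribbles and box.
    Returns the mask from the last step (None if `steps` is empty).
    """
    mask_input = initial_mask_input
    mask = None
    for prompts in steps:
        mask, _img_features, low_res_logits = predict(
            image,
            prompts["point_coords"],
            prompts["point_labels"],
            prompts["scribbles"],
            prompts["box"],
            mask_input,
        )
        # ScribblePrompt-SAM expects the previous step's low_res_logits as mask_input.
        mask_input = low_res_logits
    return mask
```

With a real model this would be called as refine_with_sam(sp_sam.predict, image, steps), where each entry of steps holds the prompts for that interaction (typically all clicks and scribbles placed so far).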
Note: our training code requires the pylot library, but the inference code above does not. We recommend installing pylot via pip:
pip install git+https://github.com/JJGO/pylot.git@87191921033c4391546fd88c5f963ccab7597995
Training settings are controlled by YAML config files. We provide two example configs in ./configs: one for fine-tuning from the pre-trained ScribblePrompt-UNet weights and one for training from scratch on an example dataset.
To fine-tune ScribblePrompt-UNet from the pre-trained weights:
python scribbleprompt/experiment/unet.py -config finetune_unet.yaml
To train a model from scratch:
python scribbleprompt/experiment/unet.py -config train_unet.yaml
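To launch several training runs from a script (for example, both example configs back to back), the commands above can be wrapped with subprocess. This is a minimal sketch: the helper names are hypothetical, it assumes the working directory is the repository root so the relative script path resolves, and it assumes pylot is installed in the interpreter that runs it.

```python
import subprocess
import sys
from typing import List

# Script path taken from the training commands above.
TRAIN_SCRIPT = "scribbleprompt/experiment/unet.py"


def train_command(config: str) -> List[str]:
    """Build the argv for one training run using the current Python interpreter."""
    return [sys.executable, TRAIN_SCRIPT, "-config", config]


def run_all(configs: List[str]) -> None:
    """Run each config in sequence; check=True stops at the first failing run."""
    for config in configs:
        subprocess.run(train_command(config), check=True)
```

run_all(["finetune_unet.yaml", "train_unet.yaml"]) runs the fine-tuning config and then the from-scratch config, raising subprocess.CalledProcessError if either run exits with a nonzero status.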
For a more in-depth tutorial, see ./notebooks/training.ipynb.
- Release Gradio demo
- Release model code and weights
- Release Jupyter notebook tutorial
- Release scribble simulation code
- Release MedScribble dataset
- Release training code
- Release segmentation labels collected using ScribblePrompt
- Our training code builds on the pylot library for deep learning experiment management. We also make use of data augmentation code originally developed for UniverSeg. Thanks to @JJGO for sharing this code!
- We use functions from voxsynth for applying random deformations during scribble simulation.
- Code for ScribblePrompt-SAM builds on Segment Anything. Thanks to Meta AI for open-sourcing the model.
If you find our work or any of our materials useful, please cite our paper:
@article{wong2024scribbleprompt,
title={ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Biomedical Image},
author={Hallee E. Wong and Marianne Rakic and John Guttag and Adrian V. Dalca},
journal={European Conference on Computer Vision (ECCV)},
year={2024},
}
Code for this project is released under the Apache 2.0 License.