Augmenting Efficient Surgical Instrument Segmentation in Video with Point Tracking and Segment Anything (SIS-PT-SAM)
This work won the Outstanding Paper Award at the MICCAI 2024 AE-CAI Workshop!
The work was accepted by the MICCAI 2024 AE-CAI Workshop!
PyTorch implementation of SIS-PT-SAM. Inference runs at 25+ FPS on a single RTX 4060 GPU and 80+ FPS on a single RTX 4090 GPU. MobileSAM is fully fine-tuned with point prompts.
[arXiv]
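Frame-rate figures like these depend on how frames are timed. The following is a minimal sketch of a timing harness, not the benchmark script used for the reported numbers; `measure_fps` and the `process_frame` callable (assumed to run one tracking plus segmentation step per frame) are hypothetical names.

```python
import time
from collections.abc import Callable, Sequence
from typing import Any


def measure_fps(process_frame: Callable[[Any], object],
                frames: Sequence[Any],
                warmup: int = 5) -> float:
    """Return the average frames per second of process_frame over frames.

    The first `warmup` frames are processed but not timed, so one-off costs
    (memory allocation, kernel selection, lazy initialisation) do not skew
    the result.
    """
    if len(frames) <= warmup:
        raise ValueError("need more frames than warmup iterations")
    for frame in frames[:warmup]:
        process_frame(frame)
    timed = frames[warmup:]
    start = time.perf_counter()
    for frame in timed:
        process_frame(frame)
    elapsed = time.perf_counter() - start
    return len(timed) / elapsed
```

When `process_frame` runs on a GPU, it should call `torch.cuda.synchronize()` before returning: CUDA kernels launch asynchronously, so without a sync the timer can stop before the work has finished and the FPS is overstated.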
- Python 3.11
- torch 2.1.2
- torchvision 0.16.2
Create a Conda environment:
conda create --name sis-pt-sam python=3.11
Activate the Conda environment and install the dependencies:
conda activate sis-pt-sam
pip install -r requirements.txt
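After installation, one way to confirm that the environment matches the versions listed above is to query the installed package metadata. A minimal sketch using the standard-library `importlib.metadata`; the helper names are hypothetical, and the check only prints warnings rather than enforcing anything.

```python
import sys
from importlib import metadata

# Versions listed in the requirements above.
EXPECTED = {"torch": "2.1.2", "torchvision": "0.16.2"}


def installed_versions(packages) -> dict:
    """Return {package: version}, with None for packages that are not installed."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def find_mismatches(installed: dict, expected: dict) -> dict:
    """Return {package: (installed, expected)} for every mismatching package.

    Local build tags such as '+cu121' are ignored, so '2.1.2+cu121'
    counts as matching '2.1.2'.
    """
    mismatches = {}
    for name, want in expected.items():
        got = installed.get(name)
        if got is None or got.split("+")[0] != want:
            mismatches[name] = (got, want)
    return mismatches


if __name__ == "__main__":
    if sys.version_info[:2] != (3, 11):
        print(f"warning: tested with Python 3.11, running {sys.version.split()[0]}")
    found = installed_versions(EXPECTED)
    for pkg, (got, want) in find_mismatches(found, EXPECTED).items():
        print(f"warning: {pkg} {got} installed, tested with {want}")
```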
Links to the publicly available datasets used in this work:
- EndoVis 2015
- EndoVis 2017
- EndoVis 2018
- ROBUST-MIS 2019
- AutoLaparo
- UCL dVRK
- CholecSeg8k
- SAR-RARP50
- STIR
Reorganize each dataset into the following top-level directory layout, where each mask in gts has the same file name as its image in imgs:
.
├── ...
├── train
│   ├── imgs
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   └── ...
│   └── gts
│       ├── 000000.png
│       ├── 000001.png
│       └── ...
├── val
│   ├── imgs
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   └── ...
│   └── gts
│       ├── 000000.png
│       ├── 000001.png
│       └── ...
└── ...
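Before training, it can help to confirm that every image has a matching mask. A minimal sketch, assuming (as the layout above shows) that each mask in `gts` has the same file name as its image in `imgs`; `check_split` and `check_dataset` are hypothetical helpers, not part of the repository.

```python
from pathlib import Path
from typing import Dict, List


def check_split(split_dir: Path) -> List[str]:
    """Return the problems found in one split directory, e.g. data/X/train."""
    imgs, gts = split_dir / "imgs", split_dir / "gts"
    problems = [f"missing directory: {d}" for d in (imgs, gts) if not d.is_dir()]
    if problems:
        return problems
    img_names = {p.name for p in imgs.glob("*.png")}
    gt_names = {p.name for p in gts.glob("*.png")}
    # Images and masks are paired purely by identical file names.
    problems += [f"image without mask: {n}" for n in sorted(img_names - gt_names)]
    problems += [f"mask without image: {n}" for n in sorted(gt_names - img_names)]
    return problems


def check_dataset(root, splits=("train", "val")) -> Dict[str, List[str]]:
    """Check each split under root; an empty list means the split looks fine."""
    return {split: check_split(Path(root) / split) for split in splits}
```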
Download the checkpoints of MobileSAM, CoTracker, and Light HQ-SAM and put them into ./ckpts.
To train on a single GPU, run:
python train.py -i ./data/[dataset]/train/ -v ./data/[dataset]/val/ --sam-ckpt ./ckpts/mobile_sam.pt --work-dir [path of the training results] --max-epochs 100 --data-aug --freeze-prompt-encoder --batch-size 4 --learn-rate 1e-5 --dataset [dataset]
For example:
python train.py -i /data/CholecSeg8k/train/ -v /data/CholecSeg8k/val/ --train-from-scratch --work-dir ./results/exp_cholecseg8k --max-epochs 100 --data-aug --freeze-prompt-encoder --batch-size 4 --learn-rate 1e-5 --dataset cholecseg
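To launch several runs from Python (for example one per dataset), the command can be built as an argument list. A minimal sketch: `build_train_cmd` is a hypothetical helper, it uses only the flags that appear in the commands above, and treating --train-from-scratch and --sam-ckpt as alternatives is an assumption based on the two examples.

```python
from pathlib import Path
from typing import List, Optional


def build_train_cmd(data_root: str, dataset: str, work_dir: str,
                    sam_ckpt: Optional[str], max_epochs: int = 100,
                    batch_size: int = 4, learn_rate: float = 1e-5,
                    multi_gpu: bool = False) -> List[str]:
    """Build a train.py argument list from the flags shown above.

    sam_ckpt=None selects --train-from-scratch instead of --sam-ckpt.
    """
    root = Path(data_root)
    cmd = ["python", "train.py",
           "-i", str(root / "train"),
           "-v", str(root / "val")]
    if sam_ckpt is None:
        cmd.append("--train-from-scratch")
    else:
        cmd += ["--sam-ckpt", sam_ckpt]
    cmd += ["--work-dir", work_dir,
            "--max-epochs", str(max_epochs),
            "--data-aug", "--freeze-prompt-encoder",
            "--batch-size", str(batch_size),
            "--learn-rate", str(learn_rate),
            "--dataset", dataset]
    if multi_gpu:
        cmd.append("--multi-gpu")
    return cmd
```

For example, `subprocess.run(build_train_cmd("/data/CholecSeg8k", "cholecseg", "./results/exp_cholecseg8k", sam_ckpt=None), check=True)` mirrors the CholecSeg8k example. Passing a list instead of a shell string avoids quoting problems with paths that contain spaces.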
To train on multiple GPUs, add --multi-gpu and set device_ids in line 82 of train.py to the GPUs you want to use:
if args.multi_gpu:
    surgicaltool_sam = nn.DataParallel(surgicaltool_sam, device_ids=[0,1,2,3])
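Instead of hard-coding device_ids, the list can be filtered against the GPUs PyTorch can actually see. A minimal sketch; `select_device_ids` and `wrap_data_parallel` are hypothetical helpers, not code from train.py, and they rely only on the standard `torch.cuda.device_count()` and `nn.DataParallel` calls.

```python
from typing import List, Sequence


def select_device_ids(requested: Sequence[int], available: int) -> List[int]:
    """Keep requested GPU indices that exist, in order, without duplicates."""
    ids: List[int] = []
    for idx in requested:
        if 0 <= idx < available and idx not in ids:
            ids.append(idx)
    return ids


def wrap_data_parallel(model, requested: Sequence[int]):
    """Wrap model in nn.DataParallel when two or more requested GPUs exist."""
    # torch is imported here so that select_device_ids has no dependencies.
    import torch
    from torch import nn

    device_ids = select_device_ids(requested, torch.cuda.device_count())
    if len(device_ids) < 2:
        return model  # zero or one usable GPU: keep the plain module
    # DataParallel requires the module's parameters on device_ids[0].
    return nn.DataParallel(model.to(f"cuda:{device_ids[0]}"), device_ids=device_ids)
```

In train.py this would read `surgicaltool_sam = wrap_data_parallel(surgicaltool_sam, [0, 1, 2, 3])`. Note that if `CUDA_VISIBLE_DEVICES` is set, PyTorch renumbers the visible GPUs from 0, so the indices refer to that reduced list.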
First, prepare the first frame of the video and its corresponding mask. If there is more than one tool to segment, put the mask of each tool into a single folder (passed as --mask_dir_path).
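A quick sanity check on that folder is to make sure it holds exactly one mask per tool, matching the value passed to --tool_number. A minimal sketch, assuming one image file per tool and that sorted file-name order is the order tools should be processed in; `list_tool_masks` and the accepted suffixes are hypothetical.

```python
from pathlib import Path
from typing import List

# Hypothetical set of accepted mask file extensions.
MASK_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp"}


def list_tool_masks(mask_dir: str, tool_number: int) -> List[Path]:
    """Return per-tool mask files in sorted order.

    Raises ValueError if the folder does not hold exactly tool_number masks.
    """
    masks = sorted(p for p in Path(mask_dir).iterdir()
                   if p.is_file() and p.suffix.lower() in MASK_SUFFIXES)
    if len(masks) != tool_number:
        raise ValueError(
            f"expected {tool_number} tool masks in {mask_dir}, found {len(masks)}")
    return masks
```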
Then run online_demo.py to start the online demo on a video:
python online_demo.py --video_path [video path] --tracker cotracker --sam_type finetune --tool_number 2 --first_frame_path [path of the first frame of the video] --first_mask_path [path of the first frame mask of the video] --mask_dir_path [folder that contains the mask of each tool in first frame] --save_demo --mode kmedoids --add_support_grid --sam-ckpt ./ckpts/[checkpoint file]
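The `--mode kmedoids` option suggests that query points are chosen by k-medoids clustering of mask pixels. The sketch below illustrates that idea under this assumption, using a basic alternating (Voronoi-iteration) k-medoids over foreground pixel coordinates; it is not the repository's implementation, and `foreground_points`, `kmedoids`, and the choice of k are hypothetical. Medoids are always actual mask pixels, so every selected point lies on the tool, whereas k-means centroids can fall outside a non-convex mask.

```python
import random
from collections.abc import Sequence
from typing import List, Tuple

Point = Tuple[int, int]  # (x, y) pixel coordinates


def foreground_points(mask: Sequence[Sequence[int]]) -> List[Point]:
    """Return (x, y) coordinates of all non-zero pixels of a binary mask."""
    return [(x, y) for y, row in enumerate(mask) for x, v in enumerate(row) if v]


def _dist2(a: Point, b: Point) -> int:
    return (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2


def kmedoids(points: Sequence[Point], k: int, iters: int = 20, seed: int = 0) -> List[Point]:
    """Select k medoids (actual input points) by alternating k-medoids.

    Each iteration assigns every point to its nearest medoid, then moves each
    medoid to the cluster member with the smallest summed squared distance
    to the rest of its cluster; it stops when the medoids no longer change.
    """
    if not 0 < k <= len(points):
        raise ValueError("k must be in 1..len(points)")
    medoids = random.Random(seed).sample(list(points), k)
    for _ in range(iters):
        clusters: List[List[Point]] = [[] for _ in medoids]
        for p in points:
            nearest = min(range(k), key=lambda i: _dist2(p, medoids[i]))
            clusters[nearest].append(p)
        new_medoids = []
        for cluster, old in zip(clusters, medoids):
            if not cluster:  # keep the old medoid if its cluster emptied
                new_medoids.append(old)
                continue
            new_medoids.append(min(cluster, key=lambda c: sum(_dist2(c, q) for q in cluster)))
        if new_medoids == medoids:
            break
        medoids = new_medoids
    return medoids
```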
If you have any problems using this code, please open an issue in this repository or contact me at [email protected]
This project is licensed under the MIT License
Thanks to the following excellent works for inspiration, code snippets, etc.