
CVPR2024 - Decomposing Disease Descriptions for Enhanced Pathology Detection: A Multi-Aspect Vision-Language Pre-training Framework

Introduction

Welcome to the official implementation code for "Decomposing Disease Descriptions for Enhanced Pathology Detection: A Multi-Aspect Vision-Language Pre-training Framework", accepted at CVPR 2024 🎉

[arXiv version](https://arxiv.org/abs/2403.07636)

This work leverages an LLM 🤖 to decompose disease descriptions into sets of visual aspects. Our multi-aspect vision-language pre-training framework, dubbed MAVL, achieves state-of-the-art performance across 7 datasets in zero-shot and low-shot fine-tuning settings for disease classification and segmentation.
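To illustrate the idea (this toy example is ours, not the repo's exact schema), a disease description decomposed into visual aspects might look like:

```python
# Hypothetical decomposition of one disease description into visual
# aspects; the actual aspects used by MAVL are generated by an LLM.
pneumonia_aspects = {
    "opacity":  "patchy or confluent airspace opacity",
    "location": "often in the lower lobes, unilateral or bilateral",
    "texture":  "fluffy consolidation with ill-defined borders",
    "shape":    "lobar or segmental distribution",
}
```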

📝 Citation

If you find our work useful, please cite our paper.

```bibtex
@article{phan2024decomposing,
    title={Decomposing Disease Descriptions for Enhanced Pathology Detection: A Multi-Aspect Vision-Language Pre-training Framework},
    author={Vu Minh Hieu Phan and Yutong Xie and Yuankai Qi and Lingqiao Liu and Liyang Liu and Bowen Zhang and Zhibin Liao and Qi Wu and Minh-Son To and Johan W. Verjans},
    year={2024},
    journal={arXiv preprint arXiv:2403.07636},
}
```

Comparisons with SOTA image-text pre-training models under zero-shot classification on 5 datasets.

| Method | CheXpert (AUC / F1 / ACC) | ChestXray-14 (AUC / F1 / ACC) | PadChest-seen (AUC / F1 / ACC) | RSNA Pneumonia (AUC / F1 / ACC) | SIIM-ACR (AUC / F1 / ACC) |
|--------|---------------------------|-------------------------------|--------------------------------|---------------------------------|---------------------------|
| ConVIRT | 52.10 / 35.61 / 57.43 | 53.15 / 12.38 / 57.88 | 63.72 / 14.56 / 73.47 | 79.21 / 55.67 / 75.08 | 64.25 / 42.87 / 53.42 |
| GLoRIA | 54.84 / 37.86 / 60.70 | 55.92 / 14.20 / 59.47 | 64.09 / 14.83 / 73.86 | 70.37 / 48.19 / 70.54 | 54.71 / 40.39 / 47.15 |
| BioViL | 60.01 / 42.10 / 66.13 | 57.82 / 15.64 / 61.33 | 60.35 / 10.63 / 70.48 | 84.12 / 54.59 / 74.43 | 70.28 / 46.45 / 68.22 |
| BioViL-T | 70.93 / 47.21 / 69.96 | 60.43 / 17.29 / 62.12 | 65.78 / 15.37 / 77.52 | 86.03 / 62.56 / 80.04 | 75.56 / 60.18 / 73.72 |
| CheXzero | 87.90 / 61.90 / 81.17 | 66.99 / 21.99 / 65.38 | 73.24 / 19.53 / 83.49 | 85.13 / 61.49 / 78.34 | 84.60 / 65.97 / 77.34 |
| MedKLIP | 87.97 / 63.67 / 84.32 | 72.33 / 24.18 / 79.40 | 77.87 / 26.63 / 92.44 | 85.94 / 62.57 / 79.97 | 89.79 / 72.73 / 83.99 |
| MAVL (Proposed) | 90.13 / 65.47 / 86.44 | 73.57 / 26.25 / 82.77 | 78.79 / 28.48 / 92.56 | 86.31 / 65.26 / 81.28 | 92.04 / 77.95 / 87.14 |

💡 Download Necessary Files

To get started, install the gdown library:

```bash
pip install -U --no-cache-dir gdown --pre
```

Then, run `bash download.sh`.
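The same download step can also be scripted from Python via gdown's API. This is a minimal sketch with a placeholder file ID, not the contents of the actual download.sh:

```python
import gdown

# Placeholder Google Drive file ID; the real IDs live in download.sh.
FILE_ID = "YOUR_FILE_ID"
gdown.download(id=FILE_ID, output="pretrained/mavl_files.zip", quiet=False)
```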

The MIMIC-CXR dataset (v2.0.0) needs to be downloaded from PhysioNet.

🚀 Library Installation

We provide a Docker image with the necessary environment. You can directly create a Docker container from our image:

```bash
docker pull stevephan46/mavl:latest
docker run --runtime=nvidia --name mavl -it -v /your/data/root/folder:/data --shm-size=4g stevephan46/mavl:latest
```

You may need to reinstall opencv-python, as there is a version conflict in the Docker environment: `pip install opencv-python==4.2.0.32`

If you prefer manual installation over Docker, run the following:

```bash
pip install -r requirements.txt
pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
pip install opencv-python==4.2.0.32
```
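After installation, a quick sanity check (our suggestion, not part of the repo) confirms that the pinned CUDA 11.3 build of PyTorch sees the GPU:

```python
import cv2
import torch

# Verify the pinned versions and GPU visibility.
print(torch.__version__)          # expected: 1.12.0+cu113
print(torch.cuda.is_available())  # expected: True on a GPU machine
print(cv2.__version__)            # expected: 4.2.0
```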

🤖 LLM-Based Disease Visual Concept Generation

The script to generate diseases' visual aspects using an LLM (GPT) can be found here.
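As a rough illustration of what such a script does, here is a minimal sketch assuming the OpenAI Python SDK; the model name, prompt wording, and output handling are placeholders, not the repo's actual script:

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

# Placeholder prompt asking the LLM to decompose a disease description
# into visual aspects; the repo's actual prompt and parsing differ.
prompt = (
    "Describe how pneumonia appears on a chest X-ray, broken down into "
    "visual aspects: opacity, shape, location, and texture."
)
response = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": prompt}],
)
print(response.choices[0].message.content)
```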

🔥 Pre-training:

Our pre-training code is given in Pretrain.

  • Run download.sh to download the necessary files.

  • Modify the paths in the config file configs/MAVL_resnet.yaml, then pre-train with python train_MAVL.py.

  • For multi-GPU pre-training, run:

```bash
accelerate launch --multi_gpu --num_processes=4 --num_machines=1 --num_cpu_threads_per_process=8 train_MAVL.py --root /data/2019.MIMIC-CXR-JPG/2.0.0 --config configs/MAVL_resnet.yaml --bs 124 --num_workers 8
```

Note: the results reported in our paper were obtained by pre-training on 4 x A100 GPUs for 60 epochs. We provide the checkpoints here. We found that a checkpoint from a later stage (checkpoint_full_46.pth) yields higher zero-shot classification accuracy, while a checkpoint from an earlier stage (checkpoint_full_40.pth) yields more stable accuracy on visual grounding.

We also ran a lighter pre-training schedule on 2 x A100 GPUs for 40 epochs with mixed-precision training, achieving similar zero-shot classification results. The checkpoint for this setup is also available here.

```bash
accelerate launch --multi_gpu --num_processes=2 --num_machines=1 --num_cpu_threads_per_process=8 --mixed_precision=fp16 train_MAVL.py --root /data/2019.MIMIC-CXR-JPG/2.0.0 --config configs/MAVL_short.yaml --bs 124 --num_workers 8
```

📦 Downstream datasets:

Links to download downstream datasets are:

  • CheXpert.
  • ChestXray-14.
  • PadChest.
  • RSNA - Download images from initial annotations.
  • SIIM.
  • COVIDx-CXR-2 - The official Kaggle link is down. A publicly available expanded version, COVIDx-CXR-4, is released here; it encompasses COVIDx-CXR-2 as a subset. Please use our dataset CSV splits to reproduce the results on the COVIDx-CXR-2 subset.
  • Covid Rural - The official link provides the raw DICOM dataset. We use the preprocessed data provided here.

🌟 Quick Start:

Check this link to download the MAVL checkpoints. They can be used for all zero-shot and fine-tuning tasks.
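A minimal sketch of inspecting a downloaded checkpoint with PyTorch; the dictionary keys are assumptions, so check the repo's test scripts for the exact loading code:

```python
import torch

# Load on CPU; the file name matches the pre-training checkpoints above.
ckpt = torch.load("checkpoint_full_46.pth", map_location="cpu")

# Training checkpoints are typically dicts; the exact keys
# (e.g. "model", "optimizer", "epoch") are assumptions here.
if isinstance(ckpt, dict):
    print(list(ckpt.keys()))
```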

  • Zero-Shot Classification:

    We give examples in Sample_Zero-Shot_Classification. Modify the paths, then test our model with python test.py --config configs/dataset_name_mavl.yaml. A generic sketch of aspect-based zero-shot scoring follows this list.

  • Zero-Shot Grounding:

    We give examples in Sample_Zero-Shot_Grounding. Modify the paths, then test our model with python test.py.

  • Finetuning:

    We give segmentation and classification fine-tuning code in Sample_Finetuning_SIIMACR. Modify the paths, then fine-tune our model with python I1_classification/train_res_ft.py --config configs/dataset_name_mavl.yaml or python I2_segementation/train_res_ft.py --config configs/dataset_name_mavl.yaml.
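For intuition, models in this family score an image against text embeddings of each disease's decomposed aspects. The sketch below is a generic illustration under that assumption; the embedding shapes and the mean-over-aspects aggregation are our simplifications, not the repo's actual pipeline:

```python
import torch
import torch.nn.functional as F

def zero_shot_scores(image_emb: torch.Tensor, aspect_embs: torch.Tensor) -> torch.Tensor:
    """Generic aspect-based zero-shot scoring sketch.

    image_emb:   (D,) embedding of one chest X-ray.
    aspect_embs: (num_diseases, num_aspects, D) text embeddings of each
                 disease's decomposed visual aspects.
    Returns a (num_diseases,) score vector.
    """
    image_emb = F.normalize(image_emb, dim=-1)
    aspect_embs = F.normalize(aspect_embs, dim=-1)
    # Cosine similarity between the image and every aspect embedding,
    # averaged over aspects to give one score per disease.
    sims = torch.einsum("d,nad->na", image_emb, aspect_embs)
    return sims.mean(dim=-1)
```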

🙏 Acknowledgement

Our code is built upon https://github.com/MediaBrain-SJTU/MedKLIP. We thank the authors for open-sourcing their code.

Feel free to reach out if you have any questions or need further assistance!