
How to load BiomedCLIP from local path #772

Open

LIKP0 opened this issue Dec 25, 2023 · 4 comments

Comments

LIKP0 commented Dec 25, 2023

Hello all, thanks for your great work first!

I'm trying to do some research with BiomedCLIP, and I followed the instructions in the example notebook:

from open_clip import create_model_from_pretrained, get_tokenizer # works on open-clip-torch>=2.23.0, timm>=0.9.8
model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

However, I ran into a problem similar to #724. Since there is no network access on my server node, I downloaded open_clip_pytorch_model.bin from the HF BiomedCLIP repo, but I can't figure out BiomedCLIP's model name:

model, preprocess = create_model_from_pretrained(model_name='ViT-B-16', pretrained="./open_clip_pytorch_model.bin")

throws a model parameter mismatch error.

I searched a lot, but it seems the BiomedCLIP architecture is not present in the existing open_clip model list?

Could anyone help me with this? Thanks in advance!


GuiQuQu commented Jan 5, 2024

I have loaded another model, 'eva02_large_patch14_clip_224.merged2b_s4b_b131k', from a local path; I'll post my code below. Hope this helps you.

from typing import Optional
import logging
from PIL import Image
import torch

import open_clip

HF_HUB_PREFIX = "hf-hub:"

logging.basicConfig(level=logging.INFO)


def get_cast_type(model) -> Optional[torch.dtype]:
    # Dtype of the model's parameters, or None for non-module inputs.
    if isinstance(model, torch.nn.Module):
        return next(model.parameters()).dtype
    return None


def get_cast_device(model) -> torch.device:
    # Device the model's parameters live on; default to CPU.
    if isinstance(model, torch.nn.Module):
        return next(model.parameters()).device
    return torch.device("cpu")

def main():
    model_name = "EVA02-L-14"
    ckpt_path = "../pretrain-model/eva02_large_patch14_clip_224.merged2b_s4b_b131k"
    ckpt_file = ckpt_path + "/open_clip_pytorch_model.bin"
    precision = "fp32"
    # EVA02-L-14 is a built-in architecture config, so only the weights
    # need to come from the local checkpoint file.
    model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
        model_name=model_name, pretrained=ckpt_file, device="cpu", precision=precision
    )
    print(type(model))
    print(next(model.parameters()))
    # The local directory also contains the tokenizer files, so the
    # hf-hub: prefix plus a local path works for get_tokenizer.
    tokenizer = open_clip.get_tokenizer(HF_HUB_PREFIX + ckpt_path)
    print(tokenizer)
    image = preprocess_val(Image.open("../CLIP.png")).unsqueeze(0)
    text = tokenizer(["a diagram", "a dog", "a cat"])
    input_device = get_cast_device(model)
    image = image.to(input_device)
    text = text.to(input_device)
    with torch.no_grad(), torch.cuda.amp.autocast():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        print("Label probs:", text_probs)


if __name__ == "__main__":
    main()
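Note that this works because EVA02-L-14 is one of open_clip's built-in architecture configs. A minimal sketch to check whether a given architecture name ships with your installed version (the substring filter is just for illustration):

import open_clip

# Architecture configs bundled with open_clip; if a name appears here,
# `pretrained` can point straight at a local checkpoint file.
print([m for m in open_clip.list_models() if "EVA02" in m])

BiomedCLIP's combination of a timm ViT image tower and a BERT text tower does not appear under any built-in name, which is why the same trick fails for it.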


jinxixiang commented Jan 10, 2024

I had the same issue. My server has no internet connection, which is painful. I found a workaround that may help you.

Step 1: download BiomedCLIP and its text encoder BiomedNLP-BiomedBERT-base-uncased-abstract, then upload them to your offline server.

Step 2: modify open_clip_config.json, changing these two lines to point to your local directory:

Original:

"text_cfg": {
      "hf_model_name": "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract",
      "hf_tokenizer_name": "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract",
      "hf_proj_type": "mlp",
      "hf_pooler_type": "cls_last_hidden_state_pooler",
      "context_length": 256
    }

Modified:


"text_cfg": {
      "hf_model_name": "/your_local_path/to/microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract",
      "hf_tokenizer_name": "/your_local_path/to/microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract",
      "hf_proj_type": "mlp",
      "hf_pooler_type": "cls_last_hidden_state_pooler",
      "context_length": 256
    }
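If you'd rather not edit the JSON by hand, the same change can be scripted; a minimal sketch, assuming the layout of the repo's open_clip_config.json (text_cfg nested under model_cfg) and hypothetical local paths:

import json

config_path = "/local_path/to/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224/open_clip_config.json"
local_bert = "/your_local_path/to/microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"

with open(config_path) as f:
    config = json.load(f)

# Point both the text model and the tokenizer at the local BERT directory.
config["model_cfg"]["text_cfg"]["hf_model_name"] = local_bert
config["model_cfg"]["text_cfg"]["hf_tokenizer_name"] = local_bert

with open(config_path, "w") as f:
    json.dump(config, f, indent=2)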

Step 3: change the code of the open_clip module. Go to the directory where open_clip_torch is installed; in my case, /opt/conda/lib/python3.8/site-packages/open_clip.

Change the model loading so it no longer downloads from Hugging Face:

Original factory.py:

def create_model(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_preprocess_cfg: Optional[Dict[str, Any]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        **model_kwargs,
):
    force_preprocess_cfg = force_preprocess_cfg or {}
    preprocess_cfg = asdict(PreprocessCfg())
    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
    if has_hf_hub_prefix:
        model_id = model_name[len(HF_HUB_PREFIX):]
        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
        config = _get_hf_config(model_id, cache_dir)
        preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
        model_cfg = config['model_cfg']
        pretrained_hf = False  # override, no need to load original HF text weights
    else:
        model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
        checkpoint_path = None
        model_cfg = None

Modified factory.py:

def create_model(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_preprocess_cfg: Optional[Dict[str, Any]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        **model_kwargs,
):
    force_preprocess_cfg = force_preprocess_cfg or {}
    preprocess_cfg = asdict(PreprocessCfg())
    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
    if has_hf_hub_prefix:
        model_id = model_name[len(HF_HUB_PREFIX):]
       
        # -------------------- modified --------------------#
        checkpoint_path = f"{cache_dir}/open_clip_pytorch_model.bin"
        config = json.load(open(f"{cache_dir}/open_clip_config.json"))

        preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
        model_cfg = config['model_cfg']
        pretrained_hf = False  # override, no need to load original HF text weights
    else:
        model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
        checkpoint_path = None
        model_cfg = None

Step 4: change the get_tokenizer function in factory.py:

Original:


def get_tokenizer(
        model_name: str = '',
        context_length: Optional[int] = None,
        **kwargs,
):
    if model_name.startswith(HF_HUB_PREFIX):
        model_name = model_name[len(HF_HUB_PREFIX):]
        try:
            config = _get_hf_config(model_name)['model_cfg']
        except Exception:
            tokenizer = HFTokenizer(
                model_name,
                context_length=context_length or DEFAULT_CONTEXT_LENGTH,
                **kwargs,
            )
            return tokenizer
    else:
        config = get_model_config(model_name)
        assert config is not None, f"No valid model config found for {model_name}."

Modified:


def get_tokenizer(
        model_name: str = '',
        cache_dir: Optional[str] = None,
        context_length: Optional[int] = None,
        **kwargs,
):

    if model_name.startswith(HF_HUB_PREFIX):
        model_name = model_name[len(HF_HUB_PREFIX):]
        try:
            #config = _get_hf_config(model_name)['model_cfg']

            # modified
            config = json.load(open(os.path.join(cache_dir, 'open_clip_config.json')))['model_cfg']

        except Exception:
            tokenizer = HFTokenizer(
                model_name,
                context_length=context_length or DEFAULT_CONTEXT_LENGTH,
                **kwargs,
            )
            return tokenizer
    else:
        config = get_model_config(model_name)
        assert config is not None, f"No valid model config found for {model_name}."
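As an alternative to editing factory.py inside site-packages (Steps 3 and 4), the two HF-hub helpers can be monkey-patched at runtime instead; a sketch, assuming an open_clip version whose factory.py matches the code above, with a hypothetical local directory:

import json
import os

from open_clip import factory

LOCAL_DIR = "/local_path/to/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"

def _local_checkpoint(model_id, cache_dir=None, **kwargs):
    # Serve the checkpoint that was copied to the server by hand.
    return os.path.join(LOCAL_DIR, "open_clip_pytorch_model.bin")

def _local_hf_config(model_id, cache_dir=None):
    # Serve the config that was copied alongside the checkpoint.
    with open(os.path.join(LOCAL_DIR, "open_clip_config.json")) as f:
        return json.load(f)

# create_model() and get_tokenizer() resolve these names at module level,
# so rebinding them redirects hf-hub: lookups to the local files.
factory.download_pretrained_from_hf = _local_checkpoint
factory._get_hf_config = _local_hf_config

With this in place, create_model_from_pretrained('hf-hub:...') should resolve everything from LOCAL_DIR without touching the network.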

Step 5: now you can load the model like this:

model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224',
                                                cache_dir='/local_path/to/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224', 
                          cache_dir='/local_path/to/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
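To confirm the offline load end to end, a quick smoke test (a sketch; model and tokenizer come from the two calls above, and the label strings are arbitrary):

import torch

texts = ["adenocarcinoma histopathology", "chest X-ray"]
tokens = tokenizer(texts)
with torch.no_grad():
    text_features = model.encode_text(tokens)
print(text_features.shape)  # expect a (2, embed_dim) tensor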


rom1504 commented Jan 10, 2024 via email


LIKP0 commented Jan 11, 2024

Hello all, I really appreciate your help!

@jinxixiang I will try it later and respond to you, thanks!

@rom1504 Could you give me some detailed code for that?

I tried the pretrained arg, but I cannot find the model name for BiomedCLIP, not even in open_clip.list_pretrained(). I guess the model name is ViT-B-16, but it doesn't work...


Following the tutorial, I tried:

model, preprocess = create_model_from_pretrained(model_name='ViT-B-16', pretrained="/localpath/BiomedCLIP/open_clip_pytorch_model.bin")

but I get a missing key error.
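To dig into the mismatch, the checkpoint keys can be inspected directly; a sketch, assuming the open_clip_pytorch_model.bin downloaded from the HF repo:

import torch

ckpt = torch.load("/localpath/BiomedCLIP/open_clip_pytorch_model.bin", map_location="cpu")
# Some .bin files wrap the weights under a 'state_dict' key.
state_dict = ckpt.get("state_dict", ckpt)
# BiomedCLIP pairs a timm ViT image tower with a Hugging Face BERT text
# tower, so its keys do not line up with the built-in ViT-B-16 config.
print(sorted(state_dict.keys())[:10])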
