
The Quick Start Code cannot be executed in mPLUG-Owl2 #195

Open · ppsmk388 opened this issue Dec 14, 2023 · 8 comments

@ppsmk388

When I run the Quick Start code from this page:

https://github.com/X-PLUG/mPLUG-Owl/tree/main/mPLUG-Owl2

import torch
from PIL import Image
from transformers import TextStreamer

from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from mplug_owl2.conversation import conv_templates, SeparatorStyle
from mplug_owl2.model.builder import load_pretrained_model
from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

image_file = '' # Image Path
model_path = 'MAGAer13/mplug-owl2-llama2-7b'
query = "Describe the image."

model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device="cuda")

conv = conv_templates["mplug_owl2"].copy()
roles = conv.roles

I got the following error:

Traceback (most recent call last):
  File "/home/kkk/.pycharm_helpers/pydev/_pydevd_bundle/pydevd_exec2.py", line 3, in Exec
    exec(exp, global_vars, local_vars)
  File "<string>", line 1, in <module>
  File "/home/kkk/DB/libs/mPLUGOwl/mPLUG-Owl2/mplug_owl2/model/builder.py", line 106, in load_pretrained_model
    model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
  File "/data/kkk/anaconda3/envs/va/lib/python3.9/site-packages/transformers/models/auto/auto_factory.py", line 566, in from_pretrained
    return model_class.from_pretrained(
  File "/data/kkk/anaconda3/envs/va/lib/python3.9/site-packages/transformers/modeling_utils.py", line 3450, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
  File "/home/kkk/DB/libs/mPLUGOwl/mPLUG-Owl2/mplug_owl2/model/modeling_mplug_owl2.py", line 209, in __init__
    self.model = MPLUGOwl2LlamaModel(config)
  File "/home/kkk/DB/libs/mPLUGOwl/mPLUG-Owl2/mplug_owl2/model/modeling_mplug_owl2.py", line 201, in __init__
    super(MPLUGOwl2LlamaModel, self).__init__(config)
  File "/home/kkk/DB/libs/mPLUGOwl/mPLUG-Owl2/mplug_owl2/model/modeling_mplug_owl2.py", line 33, in __init__
    super(MPLUGOwl2MetaModel, self).__init__(config)
  File "/data/kkk/anaconda3/envs/va/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 949, in __init__
    [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  File "/data/kkk/anaconda3/envs/va/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 949, in <listcomp>
    [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
TypeError: __init__() takes 2 positional arguments but 3 were given
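
For context, a minimal self-contained sketch of why this TypeError appears (not the actual mPLUG-Owl2 code): newer transformers releases pass a second positional layer_idx argument when building the decoder stack, so a decoder-layer class written against the older one-argument signature, as mPLUG-Owl2's patched layer appears to be, fails exactly like the traceback above. The OldStyleDecoderLayer class below is hypothetical.

# Hypothetical toy class mimicking a decoder layer written against the old
# one-argument signature (config only).
class OldStyleDecoderLayer:
    def __init__(self, config):
        self.config = config

config = {"num_hidden_layers": 2}

try:
    # Newer transformers builds the stack roughly like the line in the traceback,
    # passing layer_idx as a second positional argument.
    layers = [OldStyleDecoderLayer(config, layer_idx)
              for layer_idx in range(config["num_hidden_layers"])]
except TypeError as e:
    print(e)  # __init__() takes 2 positional arguments but 3 were given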

@vateye commented Dec 18, 2023

Update your transformers library.

@Tianchong-Jiang

Update your transformers library.

I updated to the latest version (transformers==4.36.2) but still have the problem.

@Tianchong-Jiang

I solved the problem by using transformers==4.32.0.
Using either 4.36.2 (latest) or 4.28.1 (specified in requirements.txt) caused some errors.
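
For anyone pinning versions programmatically, here is a small sketch that fails fast when the installed transformers does not match the version reported to work in this thread; the exact pin (4.32.0) comes from the comment above and is an assumption, not something the maintainers confirmed.

# Sketch: fail early if the installed transformers version is not the one
# reported to work in this thread (4.32.0).
import transformers
from packaging import version

WORKING = version.parse("4.32.0")
installed = version.parse(transformers.__version__)

if installed != WORKING:
    raise RuntimeError(
        f"This quick start was reported to work with transformers=={WORKING}; "
        f"found {installed}. Try: pip install transformers==4.32.0"
    )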

@appledora commented Feb 4, 2024

For the same snippet I got the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[8], line 6
      3 query = "Describe the image."
      5 model_name = get_model_name_from_path(model_path)
----> 6 tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device="cuda")

File /projectnb/ivc-ml/appledora/mPLUGOwl/mPLUGOwl2/mplug_owl2/model/builder.py:117, in load_pretrained_model(model_path, model_base, model_name, load_8bit, load_4bit, device_map, device, **kwargs)
    115         use_fast = False
    116         tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
--> 117         model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
    120 vision_tower = model.get_model().vision_model
    121 # vision_tower.to(device=device, dtype=torch.float16)

File /projectnb/ivc-ml/appledora/condaenvs/.conda/envs/mplug_owl2/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py:493, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    491 elif type(config) in cls._model_mapping.keys():
    492     model_class = _get_model_class(config, cls._model_mapping)
--> 493     return model_class.from_pretrained(
    494         pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    495     )
    496 raise ValueError(
    497     f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
    498     f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
    499 )

File /projectnb/ivc-ml/appledora/condaenvs/.conda/envs/mplug_owl2/lib/python3.10/site-packages/transformers/modeling_utils.py:2700, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
   2697     init_contexts.append(init_empty_weights())
   2699 with ContextManagers(init_contexts):
-> 2700     model = cls(config, *model_args, **model_kwargs)
   2702 # Check first if we are `from_pt`
   2703 if use_keep_in_fp32_modules:

File /projectnb/ivc-ml/appledora/mPLUGOwl/mPLUGOwl2/mplug_owl2/model/modeling_mplug_owl2.py:218, in MPLUGOwl2LlamaForCausalLM.__init__(self, config)
    216 def __init__(self, config):
    217     super(LlamaForCausalLM, self).__init__(config)
--> 218     self.model = MPLUGOwl2LlamaModel(config)
    220     self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
    222     # Initialize weights and apply final processing

File /projectnb/ivc-ml/appledora/mPLUGOwl/mPLUGOwl2/mplug_owl2/model/modeling_mplug_owl2.py:205, in MPLUGOwl2LlamaModel.__init__(self, config)
    204 def __init__(self, config: MPLUGOwl2Config):
--> 205     super(MPLUGOwl2LlamaModel, self).__init__(config)

File /projectnb/ivc-ml/appledora/mPLUGOwl/mPLUGOwl2/mplug_owl2/model/modeling_mplug_owl2.py:36, in MPLUGOwl2MetaModel.__init__(self, config)
     34 def __init__(self, config):
     35     super(MPLUGOwl2MetaModel, self).__init__(config)
---> 36     self.vision_model = MplugOwlVisionModel(
     37         MplugOwlVisionConfig(**config.visual_config["visual_model"])
     38     )
     39     self.visual_abstractor = MplugOwlVisualAbstractorModel(
     40         MplugOwlVisualAbstractorConfig(**config.visual_config["visual_abstractor"]), config.hidden_size
     41     )

File /projectnb/ivc-ml/appledora/mPLUGOwl/mPLUGOwl2/mplug_owl2/model/visual_encoder.py:403, in MplugOwlVisionModel.__init__(self, config)
    400 self.config = config
    401 self.hidden_size = config.hidden_size
--> 403 self.embeddings = MplugOwlVisionEmbeddings(config)
    404 self.encoder = MplugOwlVisionEncoder(config)
    405 if config.use_post_layernorm:

File /projectnb/ivc-ml/appledora/mPLUGOwl/mPLUGOwl2/mplug_owl2/model/visual_encoder.py:105, in MplugOwlVisionEmbeddings.__init__(self, config)
     95     self.cls_token = None
     97 self.patch_embed = nn.Conv2d(
     98     in_channels=3,
     99     out_channels=self.hidden_size,
   (...)
    102     bias=False,
    103 )
--> 105 if self.cls_token:
    106     self.num_patches = (self.image_size // self.patch_size) ** 2
    107     self.position_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, self.hidden_size))

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

I have the following transformers version: transformers 4.31.0

Later I upgraded it to 4.32.0 as suggested, but the error persists.
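
For context, a minimal standalone demonstration (independent of mPLUG-Owl2) of why truth-testing a multi-element tensor, as the check if self.cls_token: does, raises exactly this RuntimeError:

import torch

cls_token = torch.randn(1, 1, 8)  # stand-in for a learned cls-token parameter

try:
    if cls_token:  # the truth value of a tensor with more than one element is undefined
        pass
except RuntimeError as e:
    print(e)  # Boolean value of Tensor with more than one value is ambiguous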

@findalexli

Was anyone able to fix this?

@hiker-lw

For the same snippet I got the following error: […] RuntimeError: Boolean value of Tensor with more than one value is ambiguous
I have the following transformers version: transformers 4.31.0
Later I upgraded it to 4.32.0 as suggested, but the error persists.

Hello, you can change the check to if self.cls_token is not None, and it works for me.
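
A self-contained toy sketch of that change; ToyVisionEmbeddings below is a hypothetical stand-in for MplugOwlVisionEmbeddings, and only the None comparison mirrors the real code around visual_encoder.py line 105.

import torch
import torch.nn as nn

class ToyVisionEmbeddings(nn.Module):
    # Hypothetical stand-in for MplugOwlVisionEmbeddings, reduced to the relevant check.
    def __init__(self, image_size=224, patch_size=14, hidden_size=16, use_cls_token=True):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_size)) if use_cls_token else None

        # Patched check: "if self.cls_token:" would raise the ambiguous-Tensor error,
        # so compare against None instead of evaluating the tensor's truth value.
        if self.cls_token is not None:
            self.num_patches = (self.image_size // self.patch_size) ** 2
            self.position_embedding = nn.Parameter(
                torch.randn(1, self.num_patches + 1, self.hidden_size)
            )

print(ToyVisionEmbeddings().position_embedding.shape)  # torch.Size([1, 257, 16])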

@LukeForeverYoung (Collaborator)

For the same snippet I got the following error: […] RuntimeError: Boolean value of Tensor with more than one value is ambiguous

Hello, you can change the check to if self.cls_token is not None, and it works for me.

Yes, this issue is addressed by mPLUG-Owl2.1, which disables the cls_token in the visual encoder. We fixed this issue in the latest commit.

@appledora

Yes, I ran it last week too by turning off the cls_token check. Glad that it is now officially handled!
