
[WIP] Add LLava ONNX export #1790

Draft
mht-sharma wants to merge 8 commits into main
Conversation

@mht-sharma (Contributor) commented Apr 2, 2024

What does this PR do?

As per title!

Issue: (#1751)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines +530 to +552
if config._behavior == "encoder":
    # Embed the text tokens.
    inputs_embeds = model.get_input_embeddings()(input_ids)

    # Run the vision tower and select the hidden states of the requested layer.
    image_outputs = model.vision_tower(pixel_values, output_hidden_states=True)
    selected_image_feature = image_outputs.hidden_states[vision_feature_layer]

    if vision_feature_select_strategy == "default":
        # Drop the CLS token.
        selected_image_feature = selected_image_feature[:, 1:]
    elif vision_feature_select_strategy == "full":
        selected_image_feature = selected_image_feature
    else:
        raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}")

    # Project the image features into the language-model embedding space and
    # splice them into the token embeddings at the image placeholder positions.
    image_features = model.multi_modal_projector(selected_image_feature)
    inputs_embeds, attention_mask, labels, position_ids = model._merge_input_ids_with_image_features(
        image_features, inputs_embeds, input_ids, attention_mask, None
    )

    result = {
        "inputs_embeds": inputs_embeds,
        "decoder_attention_mask": attention_mask,
        "position_ids": position_ids,
    }
@xenova (Contributor)

I might not be understanding this 100%, but won't this be problematic for generation? We would need to re-pass the image features on every forward pass, which will merge the ids every time. This also means that we cannot embed a single text token (e.g., the one just generated).

Here's an example of a hand-crafted version of a tiny random LlavaForConditionalGeneration: https://huggingface.co/Xenova/tiny-random-LlavaForConditionalGeneration. Three models are exported there.

I've got this working with Transformers.js (v3), where the concatenation of the token/vision patch embeddings is done in JavaScript.

@mht-sharma (Contributor, Author) commented Apr 4, 2024

Hi @xenova, it should not be a problem for generation.

I generate the following three models:

  1. encoder_model.onnx - token embedding + vision tower + projection + merging
  2. decoder_model.onnx - language model only (the export is the same as the current decoder export in Optimum)
  3. decoder_input_processor.onnx - token embedding + decoder input generation when past_key_values is available (the elif branch in the modeling code)

The naming of the models could be updated later.

This is how I use the models for inference: https://gist.github.com/mht-sharma/290f7bf9052e92023b4136c6fefd6717

ONNX Model: https://huggingface.co/mohitsha/llava-1.5-7b-hf/tree/main
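
For readers who don't want to open the gist, here is a minimal sketch of how the three models could be wired together for greedy decoding with onnxruntime. The input/output names (inputs_embeds, decoder_attention_mask, logits, present.*/past_key_values.*) are assumptions based on the description above, and a single decoder graph accepting an optional KV cache is assumed; this is not the exact interface produced by this PR:

import numpy as np
import onnxruntime as ort

enc_sess = ort.InferenceSession("encoder_model.onnx")
dec_sess = ort.InferenceSession("decoder_model.onnx")
proc_sess = ort.InferenceSession("decoder_input_processor.onnx")

def run(sess, feeds):
    # Key outputs by name so "present.*" can be fed back as "past_key_values.*".
    return dict(zip([o.name for o in sess.get_outputs()], sess.run(None, feeds)))

def generate(input_ids, attention_mask, pixel_values, max_new_tokens=32):
    # Prefill: token embeddings + vision tower + projector + merge (encoder_model.onnx).
    out = run(enc_sess, {"input_ids": input_ids,
                         "attention_mask": attention_mask,
                         "pixel_values": pixel_values})
    # After merging, the mask/positions cover text tokens plus image patches.
    attention_mask = out["decoder_attention_mask"]
    feeds = {"inputs_embeds": out["inputs_embeds"],
             "attention_mask": attention_mask,
             "position_ids": out["position_ids"]}
    past, generated = {}, []
    for _ in range(max_new_tokens):
        out = run(dec_sess, {**feeds, **past})  # language model (decoder_model.onnx)
        next_id = out["logits"][:, -1].argmax(-1).reshape(1, 1)
        generated.append(int(next_id[0, 0]))
        past = {name.replace("present", "past_key_values"): value
                for name, value in out.items() if name.startswith("present")}
        # Decode step: embed only the newly generated token and rebuild the
        # mask/positions for the past stage (decoder_input_processor.onnx).
        attention_mask = np.concatenate(
            [attention_mask, np.ones_like(attention_mask[:, :1])], axis=1)
        out = run(proc_sess, {"input_ids": next_id,
                              "attention_mask": attention_mask})
        feeds = {"inputs_embeds": out["inputs_embeds"],
                 "attention_mask": attention_mask,
                 "position_ids": out["position_ids"]}
    return generated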

In this version:

  1. I do all computation as part of ONNX.
  2. The embedding model is duplicated, but it is comparatively small. If we want, we could have two additional options for this part:
    a. Create a separate embed_model.onnx and keep the rest the same. We would then have 4 ONNX models.
    b. Create a separate embed_model.onnx, do the past_key_values-stage attention_mask and position_ids processing in Python code, and remove decoder_input_processor.onnx (see the sketch below).
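
For option (b), the Python-side bookkeeping could look roughly like this (a sketch only; the real Transformers modeling code also handles padding and image-patch offsets):

import numpy as np

def prepare_decode_step(token_embeds, attention_mask):
    # Extend the attention mask by one slot for the newly generated token and
    # derive its position id from the number of attended positions, mirroring
    # how position ids are computed from the mask in Transformers.
    attention_mask = np.concatenate(
        [attention_mask, np.ones_like(attention_mask[:, :1])], axis=1)
    position_ids = attention_mask.sum(axis=1, keepdims=True).astype(np.int64) - 1
    return {"inputs_embeds": token_embeds,
            "attention_mask": attention_mask,
            "position_ids": position_ids}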

Let me know WDYT and if you have any suggestions.

@xenova (Contributor) commented Apr 4, 2024

It might also be a good idea to generalize this for other image-text-to-text models. For example, vikhyatk/moondream2, which is quite similar (or others that are actually supported by transformers).

@fxmarty mentioned this pull request Apr 16, 2024
@Pengjie-W

Could you please give me the code for converting LLaVA to ONNX?

@Pengjie-W

I'm getting this error: RuntimeError: The size of tensor a (4112) must match the size of tensor b (32) at non-singleton dimension 3

Traceback (most recent call last):
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/optimum/exporters/onnx/convert.py", line 577, in export_pytorch
    onnx_export(
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/torch/onnx/utils.py", line 1596, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/torch/onnx/utils.py", line 1135, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/torch/onnx/utils.py", line 1011, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/torch/onnx/utils.py", line 915, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/torch/jit/_trace.py", line 1285, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/torch/jit/_trace.py", line 133, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/torch/jit/_trace.py", line 124, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/optimum/exporters/onnx/model_patcher.py", line 589, in patched_forward
    outputs = model.language_model(
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1183, in forward
    outputs = self.model(
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1035, in forward
    attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py", line 398, in _prepare_4d_causal_attention_mask_for_sdpa
    expanded_4d_mask = attn_mask_converter.to_4d(
  File "/home/user/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py", line 137, in to_4d
    expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
RuntimeError: The size of tensor a (4112) must match the size of tensor b (32) at non-singleton dimension 3

@Pengjie-W

I'm running this command, which produces the error above:

optimum-cli export onnx --model llava-hf/llava-1.5-7b-hf llava_onnx/ --task image-to-text-with-past --trust-remote-code

@mht-sharma (Contributor, Author)

Hi @Pengjie-W, I will take a look later today or on Monday!
