Llama 3 support #728

Open
2 tasks done
bitterspeed opened this issue Apr 25, 2024 · 1 comment
Labels: new model (Request a new model)

Comments

@bitterspeed

Model description

Hi all,
I'm attempting to convert Llama-3 to ONNX format.

Prerequisites

  • The model is supported in Transformers (i.e., listed here)
  • The model can be exported to ONNX with Optimum (i.e., listed here)

Additional information

No response

Your contribution

When I run the following command, I get the error below:

python convert.py --quantize --model_id meta-llama/Meta-Llama-3-8B-Instruct

/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/transformers/utils/generic.py:311: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
Framework not specified. Using pt to export to ONNX.
model-00001-of-00004.safetensors: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.98G/4.98G [15:34<00:00, 5.33MB/s]
model-00002-of-00004.safetensors: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5.00G/5.00G [08:11<00:00, 10.2MB/s]
model-00003-of-00004.safetensors: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.92G/4.92G [08:18<00:00, 9.86MB/s]
model-00004-of-00004.safetensors: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.17G/1.17G [01:23<00:00, 13.9MB/s]
Downloading shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [33:30<00:00, 502.67s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:51<00:00, 12.92s/it]
generation_config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 187/187 [00:00<00:00, 611kB/s]
Automatic task detection to text-generation-with-past (possible synonyms are: causal-lm-with-past).
Using the export variant default. Available variants are:
	- default: The default ONNX variant.
use_past = False is different than use_present_in_outputs = True, the value of use_present_in_outputs value will be used for the outputs.
Using framework PyTorch: 2.3.0
Overriding 1 configuration item(s)
	- use_cache -> True
/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:595: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if input_shape[-1] > 1:
/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:119: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if seq_len > self.max_seq_len_cached:
/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:348: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:355: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:365: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
Saving external data to one file...
Using framework PyTorch: 2.3.0
Overriding 1 configuration item(s)
	- use_cache -> True
Asked a sequence length of 16, but a sequence length of 1 will be used with use_past == True for `input_ids`.
Traceback (most recent call last):
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/convert.py", line 545, in <module>
    main()
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/convert.py", line 448, in main
    main_export(**export_kwargs)
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/optimum/exporters/onnx/__main__.py", line 486, in main_export
    _, onnx_outputs = export_models(
                      ^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 752, in export_models
    export(
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 855, in export
    export_output = export_pytorch(
                    ^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 572, in export_pytorch
    onnx_export(
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/onnx/utils.py", line 1612, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/onnx/utils.py", line 1134, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/onnx/utils.py", line 1010, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/onnx/utils.py", line 914, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/jit/_trace.py", line 1310, in _get_trace_graph
    outs = ONNXTracedModule(
           ^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/jit/_trace.py", line 138, in forward
    graph, out = torch._C._create_graph_by_tracing(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/jit/_trace.py", line 129, in wrapper
    outs.append(self.inner(*trace_inputs))
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/optimum/exporters/onnx/model_patcher.py", line 113, in patched_forward
    outputs = self.orig_forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 820, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 708, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 424, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/goodspeed/Downloads/transformers.js-main/scripts/myenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 337, in forward
    key_states = torch.cat([past_key_value[0], key_states], dim=2)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 32 but got size 8 for tensor number 1 in the list.

Any ideas?
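The sizes in the error line up with Llama 3 8B's attention layout: it has 32 attention (query) heads but only 8 key/value heads. The keys computed for the new token (`key_states`) therefore carry 8 heads, while the dummy `past_key_value[0]` passed in during export apparently carries 32. The sketch below reproduces the shape check that `torch.cat` performs at that line, using plain shape tuples instead of tensors so it runs without PyTorch. The helper `cat_along_dim` is hypothetical, and the batch size and past length are illustrative values, not the settings Optimum actually uses for its dummy inputs.

```python
# Sketch: reproduce the torch.cat shape check from modeling_llama.py using plain
# shape tuples laid out as (batch, heads, seq_len, head_dim). No torch needed.

def cat_along_dim(shapes, dim):
    """Return the shape torch.cat(..., dim=dim) would produce, or raise a
    RuntimeError worded like torch's when a non-concat dimension differs."""
    ref = shapes[0]
    for i, shape in enumerate(shapes[1:], start=1):
        for d, (expected, got) in enumerate(zip(ref, shape)):
            if d != dim and expected != got:
                raise RuntimeError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected size {expected} but got size {got} "
                    f"for tensor number {i} in the list."
                )
    out = list(ref)
    out[dim] = sum(shape[dim] for shape in shapes)
    return tuple(out)


# Head counts and head_dim match Llama 3 8B; BATCH and PAST_LEN are hypothetical.
BATCH, PAST_LEN, HEAD_DIM = 2, 16, 128
NUM_HEADS, NUM_KV_HEADS = 32, 8

# Keys for the single new token: under GQA, k_proj yields NUM_KV_HEADS heads.
key_states = (BATCH, NUM_KV_HEADS, 1, HEAD_DIM)

# Suspected problem: a dummy cache sized with the query-head count.
bad_past = (BATCH, NUM_HEADS, PAST_LEN, HEAD_DIM)
try:
    cat_along_dim([bad_past, key_states], dim=2)
    error_message = None
except RuntimeError as err:
    error_message = str(err)
print(error_message)  # same text as the RuntimeError in the traceback

# A cache sized with the key/value head count concatenates cleanly.
good_past = (BATCH, NUM_KV_HEADS, PAST_LEN, HEAD_DIM)
print(cat_along_dim([good_past, key_states], dim=2))  # (2, 8, 17, 128)
```

Only dimension 1 (heads) disagrees, which is why the error reports 32 versus 8 rather than a sequence-length mismatch.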

bitterspeed added the "new model" label on Apr 25, 2024
@xenova
Owner

xenova commented Apr 25, 2024

This looks like an issue with the dummy input values, caused by Llama 3's adoption of grouped-query attention (GQA). Since the problem is in Optimum, could you repost the question there (link)?
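To illustrate the GQA point: in grouped-query attention several query heads share one key/value head, so the cached keys and values have `num_key_value_heads` heads rather than `num_attention_heads`. Any dummy past-key-value input built for export therefore has to be sized by the former. The sketch below is a minimal, hypothetical version of that sizing logic. It works on a plain dict mirroring the `config.json` fields of Llama 3 8B, and `kv_head_for_query_head` and `dummy_past_kv_shapes` are illustrative names, not functions from Optimum or Transformers.

```python
# Sketch of GQA-aware dummy cache sizing. Config keys mirror Llama's config.json;
# function names are hypothetical, not part of Optimum or Transformers.

def kv_head_for_query_head(q_head, num_heads, num_kv_heads):
    """Index of the shared key/value head that query head `q_head` attends with."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    group_size = num_heads // num_kv_heads  # query heads per key/value head
    return q_head // group_size


def dummy_past_kv_shapes(config, batch_size, past_len):
    """(key_shape, value_shape) per decoder layer for a dummy past_key_values input."""
    num_heads = config["num_attention_heads"]
    # Without num_key_value_heads, treat the model as plain multi-head attention.
    num_kv_heads = config.get("num_key_value_heads", num_heads)
    head_dim = config["hidden_size"] // num_heads
    shape = (batch_size, num_kv_heads, past_len, head_dim)
    return [(shape, shape) for _ in range(config["num_hidden_layers"])]


llama3_8b = {
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "num_hidden_layers": 32,
}

n_heads = llama3_8b["num_attention_heads"]
n_kv = llama3_8b["num_key_value_heads"]
groups = [kv_head_for_query_head(h, n_heads, n_kv) for h in range(n_heads)]
print(groups[:8])  # [0, 0, 0, 0, 1, 1, 1, 1]

# batch_size and past_len are arbitrary illustrative values.
shapes = dummy_past_kv_shapes(llama3_8b, batch_size=2, past_len=16)
print(len(shapes), shapes[0][0])  # 32 (2, 8, 16, 128)
```

The head-to-group mapping follows the same layout as the `repeat_kv` helper in Transformers' Llama modeling code, which expands each key/value head into `num_attention_heads // num_key_value_heads` consecutive query heads only after the cache has been updated; the cache itself stays at 8 heads for this model.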
