I exported the chatglm3-6b model to ONNX using a custom ONNX configuration. The exported model is functionally correct, but its latency is much higher than that of the PyTorch model.
I have attached a minimal reproducible script below that exports and runs the model. Could someone take a look and suggest how to fix the performance degradation?
```python
from optimum.exporters.onnx import main_export
from transformers import AutoConfig
from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
from optimum.exporters.onnx.base import ConfigBehavior
from optimum.utils import NormalizedTextConfig, DummyPastKeyValuesGenerator
from typing import Dict
import os
import shutil
import time


class ChatGLM2DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
    def generate(self, input_name: str, framework: str = "pt"):
        # Keys carry the past sequence length on the last axis:
        # (batch, num_heads, head_dim, past_len) ...
        past_key_shape = (
            self.batch_size,
            self.num_attention_heads,
            self.hidden_size // self.num_attention_heads,
            self.sequence_length,
        )
        # ... while values carry it on axis 2: (batch, num_heads, past_len, head_dim).
        past_value_shape = (
            self.batch_size,
            self.num_attention_heads,
            self.sequence_length,
            self.hidden_size // self.num_attention_heads,
        )
        return [
            (
                self.random_float_tensor(past_key_shape, framework=framework),
                self.random_float_tensor(past_value_shape, framework=framework),
            )
            for _ in range(self.num_layers)
        ]
```
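For reference, here is a quick sanity check of the cache shapes the generator produces (a minimal sketch that only reuses the classes above; the batch size and sequence length are arbitrary toy values, not taken from the model config):

```python
# Sanity-check sketch: instantiate the dummy generator with toy shape values
# and confirm the key/value layouts match the dynamic axes declared in the
# ONNX config below.
probe_config = AutoConfig.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
normalized = NormalizedTextConfig.with_args(
    hidden_size="hidden_size",
    num_layers="num_layers",
    num_attention_heads="num_attention_heads",
)(probe_config)
gen = ChatGLM2DummyPastKeyValuesGenerator(
    task="text-generation",
    normalized_config=normalized,
    batch_size=1,        # toy value
    sequence_length=8,   # toy value
)
key, value = gen.generate("past_key_values", framework="pt")[0]
print(key.shape)    # (1, num_heads, head_dim, 8) -- past length on axis 3
print(value.shape)  # (1, num_heads, 8, head_dim) -- past length on axis 2
```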
```python
class CustomChatGLM2OnnxConfig(TextDecoderOnnxConfig):
    DUMMY_INPUT_GENERATOR_CLASSES = (
        ChatGLM2DummyPastKeyValuesGenerator,
    ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
    DUMMY_PKV_GENERATOR_CLASS = ChatGLM2DummyPastKeyValuesGenerator
    DEFAULT_ONNX_OPSET = 15  # the aten::tril operator requires opset >= 14
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
        hidden_size="hidden_size",
        num_layers="num_layers",
        num_attention_heads="num_attention_heads",
    )

    def add_past_key_values(
        self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str
    ):
        if direction not in ["inputs", "outputs"]:
            raise ValueError(
                f'`direction` must either be "inputs" or "outputs", but {direction} was given'
            )
        if direction == "inputs":
            decoder_sequence_name = "past_sequence_length"
            name = "past_key_values"
        else:
            decoder_sequence_name = "past_sequence_length + 1"
            name = "present"
        # Declare the past length as a dynamic axis: axis 3 for keys and
        # axis 2 for values, matching the dummy generator above.
        for i in range(self._normalized_config.num_layers):
            inputs_or_outputs[f"{name}.{i}.key"] = {
                0: "batch_size",
                3: decoder_sequence_name,
            }
            inputs_or_outputs[f"{name}.{i}.value"] = {
                0: "batch_size",
                2: decoder_sequence_name,
            }
```
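To double-check the dynamic axes before running the export, the config's input spec can be inspected directly (a short sketch; it builds a throwaway with-past config the same way the export code below does):

```python
# Inspection sketch: print the declared ONNX inputs and their dynamic axes.
probe_cfg = AutoConfig.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
probe = CustomChatGLM2OnnxConfig(probe_cfg, task="text-generation", use_past=True)
for name, axes in probe.inputs.items():
    print(name, axes)
# Expect entries like:
#   past_key_values.0.key {0: 'batch_size', 3: 'past_sequence_length'}
#   past_key_values.0.value {0: 'batch_size', 2: 'past_sequence_length'}
```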
```python
model_id = "THUDM/chatglm3-6b"
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

onnx_config = CustomChatGLM2OnnxConfig(
    config=config,
    task="text-generation",
    use_past_in_inputs=False,
)
onnx_config_with_past = CustomChatGLM2OnnxConfig(
    config, task="text-generation", use_past=True
)
custom_onnx_configs = {
    "model": onnx_config,
}

main_export(
    model_id,
    output="chatglm",
    task="text-generation-with-past",
    trust_remote_code=True,
    custom_onnx_configs=custom_onnx_configs,
    no_post_process=True,
    opset=15,
)
```
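After the export finishes, it is worth confirming that the past-key-value inputs actually made it into the graph (a minimal sketch using onnxruntime; `model.onnx` is the file name recent optimum versions write into the output directory, so adjust the path if your version produces `decoder_model.onnx` instead):

```python
# Verification sketch: list the exported graph's inputs so the
# past_key_values entries and their symbolic axes can be checked by eye.
import onnxruntime as ort

session = ort.InferenceSession("chatglm/model.onnx", providers=["CPUExecutionProvider"])
for inp in session.get_inputs():
    print(inp.name, inp.shape)
```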
### Running
```python
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM
from optimum.utils import NormalizedTextConfig, NormalizedConfigManager
import torch

# Register a normalized config for the custom "chatglm" model type so that
# ORTModelForCausalLM can look up hidden_size / num_layers / num_attention_heads.
NormalizedConfigManager._conf["chatglm"] = NormalizedTextConfig

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

# Load the exported ONNX model from the output directory of main_export
# (this is what the NormalizedConfigManager registration above is needed for).
model = ORTModelForCausalLM.from_pretrained("chatglm", trust_remote_code=True)

start = time.perf_counter()
inputs = tokenizer("What is the meaning of life?", return_tensors="pt", padding=True)
input_ids = inputs.input_ids

# Generate
generate_ids = model.generate(
    input_ids,
    max_length=64,
    pad_token_id=tokenizer.eos_token_id,
)

# Stop timer
end = time.perf_counter()
generate_time = end - start

# Number of new tokens generated
prompt_tokens = input_ids.shape[1]
num_tokens_out = generate_ids.shape[1]
new_tokens_generated = num_tokens_out - prompt_tokens
time_per_token = (generate_time / new_tokens_generated) * 1e3  # ms per new token
print(time_per_token)
```
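When the ONNX model is this much slower than PyTorch, onnxruntime's built-in profiler is a reasonable first diagnostic: it shows which operators dominate the per-token latency (a minimal sketch using standard onnxruntime session options, nothing model-specific):

```python
# Profiling sketch: enable onnxruntime's operator-level profiler and run the
# session; end_profiling() writes a JSON trace that can be opened in
# chrome://tracing or Perfetto.
import onnxruntime as ort

opts = ort.SessionOptions()
opts.enable_profiling = True
session = ort.InferenceSession(
    "chatglm/model.onnx", sess_options=opts, providers=["CPUExecutionProvider"]
)
# ... feed the session a prompt and a few decode steps here ...
trace_path = session.end_profiling()
print(trace_path)
```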