Low performance of THUDM/chatglm3-6b onnx model #1846

Open
tuhinpahari opened this issue May 6, 2024 · 0 comments
I ran the chatglm3-6b model by exporting it to ONNX with a custom ONNX configuration. The functionality is correct, but the latency is very high, much higher than that of the PyTorch model.
I have attached minimal reproducible code that exports and runs the model. Can someone take a look and suggest how to fix the performance degradation?

### Export

import time
from typing import Dict

from optimum.exporters.onnx import main_export
from optimum.exporters.onnx.config import TextDecoderOnnxConfig
from optimum.utils import NormalizedTextConfig, DummyPastKeyValuesGenerator
from transformers import AutoConfig


class ChatGLM2DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
    """Generates dummy past key/value tensors for tracing the chatglm decoder."""

    def generate(self, input_name: str, framework: str = "pt"):
        # Keys: (batch, heads, head_dim, past_seq_len);
        # values: (batch, heads, past_seq_len, head_dim).
        past_key_shape = (
            self.batch_size,
            self.num_attention_heads,
            self.hidden_size // self.num_attention_heads,
            self.sequence_length,
        )
        past_value_shape = (
            self.batch_size,
            self.num_attention_heads,
            self.sequence_length,
            self.hidden_size // self.num_attention_heads,
        )
        return [
            (
                self.random_float_tensor(past_key_shape, framework=framework),
                self.random_float_tensor(past_value_shape, framework=framework),
            )
            for _ in range(self.num_layers)
        ]


class CustomChatGLM2OnnxConfig(TextDecoderOnnxConfig):
    DUMMY_INPUT_GENERATOR_CLASSES = (
        ChatGLM2DummyPastKeyValuesGenerator,
    ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
    DUMMY_PKV_GENERATOR_CLASS = ChatGLM2DummyPastKeyValuesGenerator

    DEFAULT_ONNX_OPSET = 15  # aten::tril operator requires opset>=14
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
        hidden_size="hidden_size",
        num_layers="num_layers",
        num_attention_heads="num_attention_heads",
    )

    def add_past_key_values(
        self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str
    ):

        if direction not in ["inputs", "outputs"]:
            raise ValueError(
                f'direction must either be "inputs" or "outputs", but {direction} was given'
            )

        if direction == "inputs":
            decoder_sequence_name = "past_sequence_length"
            name = "past_key_values"
        else:
            decoder_sequence_name = "past_sequence_length + 1"
            name = "present"

        # The sequence axis is dynamic: dim 3 for keys, dim 2 for values,
        # matching the dummy shapes above.
        for i in range(self._normalized_config.num_layers):
            inputs_or_outputs[f"{name}.{i}.key"] = {
                0: "batch_size",
                3: decoder_sequence_name,
            }
            inputs_or_outputs[f"{name}.{i}.value"] = {
                0: "batch_size",
                2: decoder_sequence_name,
            }

model_id = "THUDM/chatglm3-6b"
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

onnx_config = CustomChatGLM2OnnxConfig(
    config=config,
    task="text-generation",
    use_past_in_inputs=False,
)
onnx_config_with_past = CustomChatGLM2OnnxConfig(
    config, task="text-generation", use_past=True
)

# Only the no-past config is passed to the exporter; onnx_config_with_past is unused here.
custom_onnx_configs = {
    "model": onnx_config,
}

main_export(
    model_id,
    output="chatglm",
    task="text-generation-with-past",
    trust_remote_code=True,
    custom_onnx_configs=custom_onnx_configs,
    no_post_process=True,
    opset=15
)
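
As a sanity check on whether past key/values actually made it into the exported graph, something like this lists the model's inputs (a minimal sketch; it assumes the export above produced chatglm/model.onnx):

import onnx

# Load only the graph structure; skip the external weight files.
exported = onnx.load("chatglm/model.onnx", load_external_data=False)
print([inp.name for inp in exported.graph.input])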

### Running 

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM
from optimum.utils import NormalizedTextConfig, NormalizedConfigManager

# Register a normalized config for the custom "chatglm" model type so
# ORTModelForCausalLM can read its hidden size, layer count, etc.
NormalizedConfigManager._conf["chatglm"] = NormalizedTextConfig

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

# Load the exported ONNX model from the export output directory; the config
# above was created with use_past_in_inputs=False, so run without the KV cache.
model = ORTModelForCausalLM.from_pretrained("chatglm", use_cache=False, trust_remote_code=True)

start = time.perf_counter()

inputs = tokenizer("What is the meaning of life?", return_tensors="pt", padding=True)
input_ids = inputs.input_ids

# Generate
generate_ids = model.generate(
    input_ids,
    max_length=64,
    pad_token_id=tokenizer.eos_token_id,
)

# Stop timer
end = time.perf_counter()
generate_time = end - start

# Token counts
prompt_tokens = input_ids.shape[1]
num_tokens_out = generate_ids.shape[1]
new_tokens_generated = num_tokens_out - prompt_tokens

# Milliseconds per newly generated token
time_per_token = (generate_time / new_tokens_generated) * 1e3

print(f"{time_per_token:.2f} ms/token")
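
For reference, the PyTorch baseline I compare against is timed the same way (a minimal sketch reusing the tokenizer, prompt, and generation settings from above):

from transformers import AutoModelForCausalLM

# Baseline: the original PyTorch model, same prompt and generation settings.
pt_model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

start = time.perf_counter()
pt_ids = pt_model.generate(
    input_ids,
    max_length=64,
    pad_token_id=tokenizer.eos_token_id,
)
end = time.perf_counter()

pt_new_tokens = pt_ids.shape[1] - input_ids.shape[1]
print(f"PyTorch baseline: {(end - start) / pt_new_tokens * 1e3:.2f} ms/token")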
