BeamSearch op returning wrong results on CUDA execution provider when sequence is used as input_ids #20667

Open
amancini-N opened this issue May 13, 2024 · 1 comment · May be fixed by #20668
Labels
ep:CUDA issues related to the CUDA execution provider model:transformer issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc.

Comments

@amancini-N

Describe the issue

The com.microsoft::BeamSearch op outputs wrong values when the following conditions are all met:

  • Running on the CUDA execution provider
  • Using model_type = 1 (T5-like model)
  • The input_ids input of the decoder_step graph has an unbounded 2nd dimension, which means use_sequence_as_input_ids is true

Running with DEBUG_GENERATION enabled, the problem appears to lie in the copy of the sequence tensor into input_ids when feeding the decoder graph: the copy is performed host-to-device, but the sequences span already points to GPU memory. It should be a device-to-device copy instead.
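For reference, a minimal sketch (assuming the repro model below has already been saved as model.onnx) that checks the triggering condition, i.e. that the decoder subgraph's input_ids has a symbolic 2nd dimension:

import onnx

model = onnx.load("model.onnx")
beam_search = next(n for n in model.graph.node if n.op_type == "BeamSearch")
decoder = next(a.g for a in beam_search.attribute if a.name == "decoder")
input_ids_info = next(i for i in decoder.input if i.name == "input_ids")
dim = input_ids_info.type.tensor_type.shape.dim[1]
# A symbolic dimension (dim_param set) means BeamSearch will feed the whole
# generated sequence as input_ids at each decoding step.
print("input_ids dim 1:", dim.dim_param or dim.dim_value)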

To reproduce

The following script creates a dummy model and tests it with both the CPU and CUDA EPs:

import onnx
import onnxruntime
import numpy as np

VOCAB_SIZE = 20
NUM_HEADS = 2
HEAD_SIZE = 4
HIDDEN_SIZE = NUM_HEADS * HEAD_SIZE

def create_encoder_graph(encoder_embedding_weight, decoder_embedding_weight, decoder_linear_weight):
    # Create an ONNX graph taking as input:
    # - an int sequence called encoder_input_ids: (batch_size, encode_sequence_length)
    # - a mask sequence called encoder_attention_mask
    # - a BOS token called decoder_input_ids
    # and returning as outputs:
    # - the encoder output: (batch_size, encode_sequence_length, hidden_size)
    # - the decoder output: (batch_size, 1, vocab_size)
    # - the output key state from the decoder named present_key_self_0: (batch_size, num_heads, 1, head_size)
    # - the output value state from the decoder named present_value_self_0: (batch_size, num_heads, 1, head_size)
    # - the output key state from the encoder named present_cross_key: (batch_size, num_heads, encode_sequence_length, head_size)
    # - the output value state from the encoder named present_cross_value: (batch_size, num_heads, encode_sequence_length, head_size)
    # num_heads = 2
    # head_size = 4
    # hidden_size = 8
    # vocab_size = 20
    # For simplicity, encoder only contains an embedding layer
    # Decoder instead contains an embedding layer, a linear projection and a sum between it and the encoder output. Final output is a softmax layer
    input_name = "encoder_input_ids"

    # Create input tensor
    input_tensor = onnx.helper.make_tensor_value_info(input_name, onnx.TensorProto.INT32, ["batch_size", "encode_sequence_length"])
    mask_tensor = onnx.helper.make_tensor_value_info("encoder_attention_mask", onnx.TensorProto.INT32, ["batch_size", "encode_sequence_length"])

    # Create embedding layer
    embedding_node = onnx.helper.make_node("Gather", ["embedding_weight", input_name], ["encoder_hidden_states"], name="embedding")
    embedding_weight_initializer = onnx.helper.make_tensor("embedding_weight", onnx.TensorProto.FLOAT, [VOCAB_SIZE, HIDDEN_SIZE], encoder_embedding_weight.flatten())

    # Create encoder layer
    encoder_output = onnx.helper.make_tensor_value_info("encoder_hidden_states", onnx.TensorProto.FLOAT, ["batch_size", "encode_sequence_length", HIDDEN_SIZE])

    # Create decoder input
    decoder_input = onnx.helper.make_tensor_value_info("decoder_input_ids", onnx.TensorProto.INT32, ["batch_size", 1])

    # Create decoder embedding layer
    decoder_embedding_node = onnx.helper.make_node("Gather", ["decoder_embedding_weight", "decoder_input_ids"], ["decoder_embedding_output"], name="decoder_embedding")
    decoder_embedding_weight_initializer = onnx.helper.make_tensor("decoder_embedding_weight", onnx.TensorProto.FLOAT, [VOCAB_SIZE, HIDDEN_SIZE], decoder_embedding_weight.flatten())

    # Create decoder output
    decoder_output = onnx.helper.make_tensor_value_info("logits", onnx.TensorProto.FLOAT, ["batch_size", 1, VOCAB_SIZE])

    # Reduce mean of encoder output
    encoder_output_mean = onnx.helper.make_node("ReduceMean", ["encoder_hidden_states"], ["encoder_hidden_states_mean"], axes=[1])

    # Create sum node
    sum_node = onnx.helper.make_node("Add", ["decoder_embedding_output", "encoder_hidden_states_mean"], ["sum_output"], name="sum")

    # Create linear projection
    linear_node = onnx.helper.make_node("MatMul", ["sum_output", "W"], ["linear_output"], name="linear")
    linear_weight_initializer = onnx.helper.make_tensor("W", onnx.TensorProto.FLOAT, [HIDDEN_SIZE, VOCAB_SIZE], decoder_linear_weight.flatten())

    # Create softmax node
    softmax_node = onnx.helper.make_node("Softmax", ["linear_output"], ["logits"], name="softmax")

    # Create output key and value states
    present_self_key = onnx.helper.make_tensor_value_info("present_key_self_0", onnx.TensorProto.FLOAT, ["batch_size", NUM_HEADS, 1, HEAD_SIZE])
    present_self_value = onnx.helper.make_tensor_value_info("present_value_self_0", onnx.TensorProto.FLOAT, ["batch_size", NUM_HEADS, 1, HEAD_SIZE])

    # Obtain key and value states from reshaping the encoder/decoder sum
    final_shape_as_constant = onnx.helper.make_node("Constant", [], ["final_shape"], value=onnx.helper.make_tensor("value", onnx.TensorProto.INT64, [4], [-1, 1, NUM_HEADS, HEAD_SIZE]))
    key_node = onnx.helper.make_node("Reshape", ["sum_output", "final_shape"], ["to_transpose_self_key"], name="key_reshape")
    value_node = onnx.helper.make_node("Reshape", ["sum_output", "final_shape"], ["to_transpose_self_value"], name="value_reshape")
    transposed_key_node = onnx.helper.make_node("Transpose", ["to_transpose_self_key"], ["present_key_self_0"], perm=[0, 2, 1, 3])
    transposed_value_node = onnx.helper.make_node("Transpose", ["to_transpose_self_value"], ["present_value_self_0"], perm=[0, 2, 1, 3])

    # Create output key and value states from the encoder
    present_cross_key = onnx.helper.make_tensor_value_info("present_cross_key", onnx.TensorProto.FLOAT, ["batch_size", NUM_HEADS, "encode_sequence_length", HEAD_SIZE])
    present_cross_value = onnx.helper.make_tensor_value_info("present_cross_value", onnx.TensorProto.FLOAT, ["batch_size", NUM_HEADS, "encode_sequence_length", HEAD_SIZE])

    # Obtain key and value states from reshaping the encoder output
    encoder_batch_seq_len = onnx.helper.make_node("Shape", ["encoder_hidden_states"], ["encoder_batch_seq_len"], end=2)
    num_heads_and_size = onnx.helper.make_node("Constant", [], ["num_heads_and_size"], value=onnx.helper.make_tensor("value", onnx.TensorProto.INT64, [2], [NUM_HEADS, HEAD_SIZE]))
    encoder_final_shape = onnx.helper.make_node("Concat", ["encoder_batch_seq_len", "num_heads_and_size"], ["encoder_final_shape"], axis=0)
    encoder_key_node = onnx.helper.make_node("Reshape", ["encoder_hidden_states", "encoder_final_shape"], ["to_transpose_cross_key"], name="encoder_key_reshape")
    encoder_value_node = onnx.helper.make_node("Reshape", ["encoder_hidden_states", "encoder_final_shape"], ["to_transpose_cross_value"], name="encoder_value_reshape")
    encoder_transposed_key_node = onnx.helper.make_node("Transpose", ["to_transpose_cross_key"], ["present_cross_key"], perm=[0, 2, 1, 3])
    encoder_transposed_value_node = onnx.helper.make_node("Transpose", ["to_transpose_cross_value"], ["present_cross_value"], perm=[0, 2, 1, 3])


    # Create graph
    graph = onnx.helper.make_graph(
        nodes=[final_shape_as_constant, embedding_node, decoder_embedding_node, encoder_output_mean, sum_node, linear_node, softmax_node, key_node, value_node, encoder_batch_seq_len, num_heads_and_size, encoder_final_shape, encoder_key_node, encoder_value_node, transposed_key_node, transposed_value_node, encoder_transposed_key_node, encoder_transposed_value_node],
        name="encoder_decoder_init",
        inputs=[input_tensor, mask_tensor, decoder_input],
        outputs=[decoder_output, encoder_output, present_self_key, present_self_value, present_cross_key, present_cross_value],
        initializer=[embedding_weight_initializer, decoder_embedding_weight_initializer, linear_weight_initializer]
    )

    return graph

def create_decoder_graph(decoder_embedding_weight, decoder_linear_weight):
    # Create an ONNX graph taking as input:
    # - an int sequence called input_ids: (batch_size, decode_sequence_length)
    # - a mask sequence called encoder_attention_mask: (batch_size, encode_sequence_length)
    # - a float tensor called encoder_hidden_states: (batch_size, encode_sequence_length, hidden_size)
    # - a float tensor called past_self_key: (batch_size, num_heads, decode_sequence_length, head_size)
    # - a float tensor called past_self_value: (batch_size, num_heads, decode_sequence_length, head_size)
    # - a float tensor called past_cross_key: (batch_size, num_heads, encode_sequence_length, head_size)
    # - a float tensor called past_cross_value: (batch_size, num_heads, encode_sequence_length, head_size)
    # and returning as outputs:
    # - the decoder output: (batch_size, decode_sequence_length, vocab_size)
    # - the output key state from the decoder, named present_key: (batch_size, num_heads, present_sequence_length, head_size)
    # - the output value state from the decoder, named present_value: (batch_size, num_heads, present_sequence_length, head_size)
    # hidden_size = 8
    # vocab_size = 20
    # num_heads = 2
    # head_size = 4
    # Decoder contains an embedding layer, a linear projection and a sum between it and the encoder output. Final output is a softmax layer

    # Create input tensor
    input_name = "input_ids"
    input_tensor = onnx.helper.make_tensor_value_info(input_name, onnx.TensorProto.INT32, ["batch_size", "decode_sequence_length"])
    encoder_attention_mask = onnx.helper.make_tensor_value_info("encoder_attention_mask", onnx.TensorProto.INT32, ["batch_size", "encode_sequence_length"])
    encoder_hidden_states = onnx.helper.make_tensor_value_info("encoder_hidden_states", onnx.TensorProto.FLOAT, ["batch_size", "encode_sequence_length", HIDDEN_SIZE])
    past_self_key = onnx.helper.make_tensor_value_info("past_self_key", onnx.TensorProto.FLOAT, ["batch_size", NUM_HEADS, "decode_sequence_length", HEAD_SIZE])
    past_self_value = onnx.helper.make_tensor_value_info("past_self_value", onnx.TensorProto.FLOAT, ["batch_size", NUM_HEADS, "decode_sequence_length", HEAD_SIZE])
    past_cross_key = onnx.helper.make_tensor_value_info("past_cross_key", onnx.TensorProto.FLOAT, ["batch_size", NUM_HEADS, "encode_sequence_length", HEAD_SIZE])
    past_cross_value = onnx.helper.make_tensor_value_info("past_cross_value", onnx.TensorProto.FLOAT, ["batch_size", NUM_HEADS, "encode_sequence_length", HEAD_SIZE])

    # Create decoder embedding layer
    decoder_embedding_node = onnx.helper.make_node("Gather", ["decoder_embedding_weight", input_name], ["decoder_embedding_output"], name="decoder_embedding")
    decoder_embedding_weight_initializer = onnx.helper.make_tensor("decoder_embedding_weight", onnx.TensorProto.FLOAT, [VOCAB_SIZE, HIDDEN_SIZE], decoder_embedding_weight.flatten())

    # Create decoder output
    logits = onnx.helper.make_tensor_value_info("logits", onnx.TensorProto.FLOAT, ["batch_size", "decode_sequence_length", VOCAB_SIZE])

    # Reduce mean of encoder output
    encoder_output_mean = onnx.helper.make_node("ReduceMean", ["encoder_hidden_states"], ["encoder_hidden_states_mean"], axes=[1])

    # Create sum node
    sum_node = onnx.helper.make_node("Add", ["decoder_embedding_output", "encoder_hidden_states_mean"], ["sum_output"], name="sum")

    # Create linear projection
    linear_node = onnx.helper.make_node("MatMul", ["sum_output", "W"], ["linear_output"], name="linear")
    linear_weight_initializer = onnx.helper.make_tensor("W", onnx.TensorProto.FLOAT, [HIDDEN_SIZE, VOCAB_SIZE], decoder_linear_weight.flatten())

    # Create softmax node
    softmax_node = onnx.helper.make_node("Softmax", ["linear_output"], ["logits"], name="softmax")

    # Create output key and value states
    output_key = onnx.helper.make_tensor_value_info("present_key", onnx.TensorProto.FLOAT, ["batch_size", NUM_HEADS, "present_sequence_length", HEAD_SIZE])
    output_value = onnx.helper.make_tensor_value_info("present_value", onnx.TensorProto.FLOAT, ["batch_size", NUM_HEADS, "present_sequence_length", HEAD_SIZE])

    # Obtain key and value states from reshaping the encoder/decoder sum, concatenate with past_self_key and past_self_value
    # First, build a tensor containing the final shape: (batch_size, -1, NUM_HEADS, HEAD_SIZE)
    batch_size = onnx.helper.make_node("Shape", ["sum_output"], ["batch_size"], end=1)
    final_shape_without_batch = onnx.helper.make_node("Constant", [], ["final_shape_without_batch"], value=onnx.helper.make_tensor("value", onnx.TensorProto.INT64, [3], [-1, NUM_HEADS, HEAD_SIZE]))
    final_shape = onnx.helper.make_node("Concat", ["batch_size", "final_shape_without_batch"], ["final_shape"], axis=0)

    key_node = onnx.helper.make_node("Reshape", ["sum_output", "final_shape"], ["output_key"], name="key_reshape")
    value_node = onnx.helper.make_node("Reshape", ["sum_output", "final_shape"], ["output_value"], name="value_reshape")
    transposed_key_node = onnx.helper.make_node("Transpose", ["output_key"], ["output_key_transposed"], perm=[0, 2, 1, 3])
    transposed_value_node = onnx.helper.make_node("Transpose", ["output_value"], ["output_value_transposed"], perm=[0, 2, 1, 3])
    key_concat_node = onnx.helper.make_node("Concat", ["past_self_key", "output_key_transposed"], ["present_key"], axis=2)
    value_concat_node = onnx.helper.make_node("Concat", ["past_self_value", "output_value_transposed"], ["present_value"], axis=2)

    # Create graph
    graph = onnx.helper.make_graph(
        nodes=[decoder_embedding_node, encoder_output_mean, sum_node, batch_size, final_shape_without_batch, final_shape, linear_node, softmax_node, key_node, value_node, transposed_key_node, transposed_value_node, key_concat_node, value_concat_node],
        name="decoder_step",
        inputs=[input_tensor, encoder_attention_mask, encoder_hidden_states, past_self_key, past_self_value, past_cross_key, past_cross_value],
        outputs=[logits, output_key, output_value],
        initializer=[decoder_embedding_weight_initializer, linear_weight_initializer]
    )

    return graph

def create_model_with_beam_search():
    # Create an ONNX model with two subgraphs: encoder and decoder
    # Encoder contains an embedding layer
    # Decoder contains an embedding layer and a sum between it and the encoder output. Final output is a softmax layer
    encoder_embedding_weight = np.random.rand(VOCAB_SIZE, HIDDEN_SIZE)
    # encoder_embedding_weight = np.arange(VOCAB_SIZE*HIDDEN_SIZE).reshape(VOCAB_SIZE, HIDDEN_SIZE).astype(np.float32)
    decoder_embedding_weight = np.random.rand(VOCAB_SIZE, HIDDEN_SIZE)
    # decoder_embedding_weight = np.arange(0, VOCAB_SIZE*HIDDEN_SIZE*2, 2).reshape(VOCAB_SIZE, HIDDEN_SIZE).astype(np.float32)
    decoder_linear_weight = np.random.rand(HIDDEN_SIZE, VOCAB_SIZE)
    # decoder_linear_weight = np.arange(0, HIDDEN_SIZE*VOCAB_SIZE*5, 5).reshape(HIDDEN_SIZE, VOCAB_SIZE).astype(np.float32)

    encoder_graph = create_encoder_graph(encoder_embedding_weight, decoder_embedding_weight, decoder_linear_weight)
    decoder_graph = create_decoder_graph(decoder_embedding_weight, decoder_linear_weight)

    # Create input tensor
    encoder_input = onnx.helper.make_tensor_value_info("encoder_input_ids", onnx.TensorProto.INT32, ["batch_size", "encode_sequence_length"])

    # Create output tensor
    sequences_output = onnx.helper.make_tensor_value_info("sequences", onnx.TensorProto.INT32, ["batch_size", 3, None])

    num_beams_tensor = onnx.helper.make_tensor("num_beams", onnx.TensorProto.INT32, [], [3])
    num_beams_as_constant = onnx.helper.make_node("Constant", [], ["num_beams"], value=num_beams_tensor)
    min_length_tensor = onnx.helper.make_tensor("min_length", onnx.TensorProto.INT32, [], [1])
    min_length_as_constant = onnx.helper.make_node("Constant", [], ["min_length"], value=min_length_tensor)
    max_length_tensor = onnx.helper.make_tensor("max_length", onnx.TensorProto.INT32, [], [10])
    max_length_as_constant = onnx.helper.make_node("Constant", [], ["max_length"], value=max_length_tensor)
    length_penalty_tensor = onnx.helper.make_tensor("length_penalty", onnx.TensorProto.FLOAT, [], [0.6])
    length_penalty_as_constant = onnx.helper.make_node("Constant", [], ["length_penalty"], value=length_penalty_tensor)

    # Create beam search node
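    # Input order for com.microsoft::BeamSearch: input_ids, max_length, min_length,
    # num_beams, num_return_sequences, length_penalty. "num_beams" is wired in twice
    # so that num_return_sequences == num_beams == 3.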
    beam_search_node = onnx.helper.make_node(
        "BeamSearch",
        ["encoder_input_ids", "max_length", "min_length", "num_beams", "num_beams", "length_penalty"],
        ["sequences"],
        decoder=decoder_graph,
        encoder=encoder_graph,
        decoder_start_token_id=2,
        early_stopping=0,
        eos_token_id=2,
        model_type=1,
        pad_token_id=1,
        name="beam_search",
        domain="com.microsoft"
    )

    # Create main graph
    graph = onnx.helper.make_graph(
        nodes=[beam_search_node, num_beams_as_constant, min_length_as_constant, max_length_as_constant, length_penalty_as_constant],
        name="model",
        inputs=[encoder_input],
        outputs=[sequences_output],
    )

    # Create model
    model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 17)])

    return model

def run_model_with_different_EPs():
    # Initialize sessions with CPU and GPU providers
    cpu_session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
    gpu_session = onnxruntime.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])

    # Create input
    input_ids = np.random.randint(0, VOCAB_SIZE, (1, 5)).astype(np.int32)

    print("Input: ", input_ids)

    # Run model with CPU provider
    cpu_output = cpu_session.run(None, {"encoder_input_ids": input_ids})

    # Run model with GPU provider
    gpu_output = gpu_session.run(None, {"encoder_input_ids": input_ids})

    # Test if outputs are the same
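    # On affected CUDA builds this assertion fails: the GPU sequences diverge from the CPU result.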
    np.testing.assert_equal(cpu_output, gpu_output)

    print("Output: ", cpu_output)

# Create model
model = create_model_with_beam_search()

# Save model
onnx.save(model, "model.onnx")

# Run model with different execution providers
run_model_with_different_EPs()

print("Model saved and tested successfully")

Urgency

No response

Platform

Linux

OS Version

Ubuntu 20.04.5 LTS

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

737eb48

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

CUDA 11.8

@github-actions github-actions bot added ep:CUDA issues related to the CUDA execution provider model:transformer issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc. labels May 13, 2024
@amancini-N
Author

Opened a PR fixing the issue: #20668
