Describe the issue

Running a model with a LayerNormalization op produces wrong results when all of the following conditions apply:
The following graph reproduces the issue:
This appears to be caused by a side effect of the Dnnl LayerNormalization kernel, which modifies its input tensor in place.
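The suspected failure mode can be sketched in plain NumPy (a hypothetical illustration, not the actual Dnnl kernel code): if the kernel writes the normalized values back into its input buffer, the downstream Add consumes the normalized tensor instead of the original input, so y = LayerNormalization(x) + x effectively becomes 2 * LayerNormalization(x).

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Side-effect-free reference implementation.
    centered = x - x.mean(axis=-1, keepdims=True)
    return centered / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def layer_norm_inplace(x, eps=1e-5):
    # Hypothetical kernel that normalizes x in its own buffer,
    # mimicking the suspected side effect on the input.
    x -= x.mean(axis=-1, keepdims=True)
    x /= np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    return x

x = np.random.rand(3, 3).astype(np.float32)

# Correct result: Add sees the original input.
expected = layer_norm(x) + x

# Buggy result: the input buffer is clobbered with the normalized
# values before Add reads it.
x_clobbered = x.copy()
y = layer_norm_inplace(x_clobbered)
actual = y + x_clobbered  # Add now sees the normalized tensor

assert not np.allclose(expected, actual)
```

This only bites when the LayerNormalization input has another consumer in the graph, which matches the minimal model below.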
To reproduce

The following snippet builds the minimal example model and compares the results of the CPU and oneDNN EPs:
```python
import onnx
import onnx.helper
import numpy as np
import onnxruntime


def create_layer_norm_model():
    # Define input and output names
    input_name = 'input'
    output_name = 'output'

    # Define input shape
    input_shape = ('batch_size', 3)

    # Create input tensor
    input_tensor = onnx.helper.make_tensor_value_info(
        input_name, onnx.TensorProto.FLOAT, input_shape)

    # Constant scale (ones)
    scale = onnx.helper.make_node(
        'Constant', inputs=[], outputs=['scale'],
        value=onnx.numpy_helper.from_array(np.ones((3,), dtype=np.float32))
    )

    # Constant bias (zeros)
    bias = onnx.helper.make_node(
        'Constant', inputs=[], outputs=['bias'],
        value=onnx.numpy_helper.from_array(np.zeros((3,), dtype=np.float32))
    )

    # Create layer normalization node
    layer_norm_node = onnx.helper.make_node(
        'LayerNormalization',
        inputs=[input_name, 'scale', 'bias'],
        outputs=[output_name],
        epsilon=1e-5
    )

    # Create addition node: reuses the LayerNormalization input
    add_node = onnx.helper.make_node(
        'Add',
        inputs=[input_name, output_name],
        outputs=[output_name + '_add']
    )

    # Create output tensor
    output_tensor = onnx.helper.make_tensor_value_info(
        output_name + '_add', onnx.TensorProto.FLOAT, input_shape)

    # Create graph (nodes listed in topological order)
    graph_def = onnx.helper.make_graph(
        nodes=[scale, bias, layer_norm_node, add_node],
        name='layer_norm_model',
        inputs=[input_tensor],
        outputs=[output_tensor]
    )

    # Create model
    return onnx.helper.make_model(
        graph_def,
        producer_name='onnx-example',
        doc_string="A model with input x and output y = LayerNormalization(x) + x"
    )


def test_model_against_cpu_and_dnnl_eps(model):
    # Generate dummy input
    input_data = np.random.rand(3, 3).astype(np.float32)

    # Create ONNX Runtime sessions with different execution providers
    sess_cpu = onnxruntime.InferenceSession(model, providers=['CPUExecutionProvider'])
    sess_dnnl = onnxruntime.InferenceSession(model, providers=['DnnlExecutionProvider'])

    # Run the model on both providers
    output_cpu = sess_cpu.run(['output_add'], {'input': input_data})
    output_dnnl = sess_dnnl.run(['output_add'], {'input': input_data})

    # Check if the outputs are the same
    np.testing.assert_allclose(output_cpu, output_dnnl, rtol=1e-03, atol=1e-05)


# Create the ONNX model
model = create_layer_norm_model()

# Save the model to a file
onnx.save(model, 'layer_norm_model.onnx')

# Test the model against the CPU and DNNL execution providers
test_model_against_cpu_and_dnnl_eps('layer_norm_model.onnx')
```
Urgency

No response

Platform

Linux

OS Version

Ubuntu 20.04.5 LTS

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

737eb48

ONNX Runtime API

Python

Architecture

X64

Execution Provider

oneDNN

Execution Provider Library Version

No response
Opened a PR fixing the issue: #20624