
NaN results for PyTorch models on versions > 6.3.0 #2223

Closed
hjrnunes opened this issue May 17, 2024 · 3 comments
Labels
bug Unexpected behaviour that should be corrected (type)

Comments

@hjrnunes

🐞 Describing the bug

Versions later than 6.3.0 (i.e. 7.0b1 and later) produce NaN results for some models. Version 6.3.0 works fine with the exact same models and conversion code.

To Reproduce

Run the following examples with version 6.3.0 and they produce numeric results. Run them with e.g. 7.2 and they produce NaN.

Example 1

import coremltools as ct
import numpy
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("amberoad/bert-multilingual-passage-reranking-msmarco")
base_model = AutoModelForSequenceClassification.from_pretrained("amberoad/bert-multilingual-passage-reranking-msmarco")

encoded_input = tokenizer('How many people live in Berlin?',
                          'Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.',
                          return_tensors='pt')


class ModelWrapper(torch.nn.Module):
    def __init__(self, model):
        super(ModelWrapper, self).__init__()
        self.model = model

    def forward(self, input_ids, token_type_ids, attention_mask):
        with torch.no_grad():
            model_result = self.model(input_ids=input_ids,
                                      token_type_ids=token_type_ids,
                                      attention_mask=attention_mask,
                                      return_dict=True)

        return model_result.logits


wrapped_model = ModelWrapper(base_model)
traced_model = torch.jit.trace(wrapped_model.eval(),
                               (
                                   encoded_input['input_ids'],
                                   encoded_input['token_type_ids'],
                                   encoded_input['attention_mask']))
traced_model.eval()

inputs_shape = ct.Shape(shape=(1, ct.RangeDim(lower_bound=1, upper_bound=512)))

model_cm = ct.convert(
    traced_model,
    inputs=[
        ct.TensorType(name="input_ids", shape=inputs_shape, dtype=numpy.int32),
        ct.TensorType(name="token_type_ids", shape=inputs_shape, dtype=numpy.int32),
        ct.TensorType(name="attention_mask", shape=inputs_shape, dtype=numpy.int32)
    ],
    outputs=[ct.TensorType(name="logits")],
    convert_to="mlprogram",
)

x_cm = {k: v.to(torch.int32).numpy() for k, v in encoded_input.items()}
y_cm = model_cm.predict(x_cm)

print(y_cm)

Example 2

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
import coremltools as ct

sentences = ["This is a test."]
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2', torchscript=True).eval()
encoded_input = tokenizer(sentences, return_tensors='pt')
input_values = tuple(encoded_input.values())

x = (encoded_input['input_ids'], encoded_input['attention_mask'], encoded_input['token_type_ids'])

traced_model = torch.jit.trace(model, x)

model_cm = ct.convert(
    traced_model, source="pytorch",
    inputs=[
        ct.TensorType(name="input_ids", shape=input_values[0].shape, dtype=np.int32),
        ct.TensorType(name="attention_mask", shape=input_values[2].shape, dtype=np.int32),
        ct.TensorType(name="token_type_ids", shape=input_values[1].shape, dtype=np.int32)
    ],
    outputs=[
        ct.TensorType(name="out"),
        ct.TensorType(name="hidden_states"),
    ],
    compute_units=ct.ComputeUnit.CPU_ONLY,
    convert_to="mlprogram",  # convert_to="mlprogram",
)

out_t, hidden_states_t = traced_model(*x)

x_cm = {k: v.to(torch.int32).numpy() for k, v in encoded_input.items()}
y_cm = model_cm.predict(x_cm)

print(y_cm)

Related issue: #1809

Note that this is the same code as in #1809, but I have changed convert_to back to mlprogram.

As the earlier report used version 6.1, I am not sure whether mlprogram started working in 6.3.0 and then stopped working again later, or whether something else is going on. The fact is that it works on 6.3.0 but does not work on 7.2.

Other potentially related issues

#1932

This earlier report by @ZachNagengast provided the resolution for my case (thanks!).
Even though the symptoms there are slightly different, in that inference does work in Python (i.e. it does not return NaN), the solution is the same, so I'd wager it is the same issue.

#2166

I came across this in my research, but I have not reproduced it, as I couldn't be bothered to get past the authentication wall. Looking at the 7.1 version mentioned there, though, I'd wager it is actually the same issue.

System environment (please complete the following information):

  • coremltools version: 6.3.0 and 7.2
  • OS (e.g. MacOS version or Linux type): macOS 13.6.6
  • Any other relevant version information (e.g. PyTorch or TensorFlow version): PyTorch 2.2.0
hjrnunes added the bug (Unexpected behaviour that should be corrected) label on May 17, 2024
@YifanShenSZ
Collaborator

YifanShenSZ commented May 20, 2024

Hi @hjrnunes, with coremltools 7.2, could you please try compute_precision=ct.precision.FLOAT32?

Concretely, when I read #1932, the culprit seems to be an infinity in a clip op that is trying to clip a huge fp32 value to 3.4e38. That 3.4e38 is not representable in fp16, so it becomes inf if we use fp16 compute precision.
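
To make that concrete, here is a minimal NumPy sketch (not the converter's actual code) of the suspected failure mode: the largest finite fp32 value (about 3.4e38) has no fp16 representation, so casting it overflows to inf, and inf arithmetic then yields NaN downstream.

import numpy as np

# Clip bound used under fp32: the largest finite float32 (~3.4e38).
fp32_max = np.finfo(np.float32).max

# fp16 tops out around 65504, so the cast overflows to infinity.
as_fp16 = np.float16(fp32_max)
print(as_fp16)            # inf
print(as_fp16 - as_fp16)  # nan -- inf minus inf, which then propagates as NaN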

@hjrnunes
Author

Thanks, @YifanShenSZ

That does seem to fix both examples.

@hjrnunes
Author

Update for anyone landing here. These are the fixed examples that do work with 7.2:

Example 1

import coremltools as ct
import numpy
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("amberoad/bert-multilingual-passage-reranking-msmarco")
base_model = AutoModelForSequenceClassification.from_pretrained("amberoad/bert-multilingual-passage-reranking-msmarco")

encoded_input = tokenizer('How many people live in Berlin?',
                          'Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.',
                          return_tensors='pt')


class ModelWrapper(torch.nn.Module):
    def __init__(self, model):
        super(ModelWrapper, self).__init__()
        self.model = model

    def forward(self, input_ids, token_type_ids, attention_mask):
        with torch.no_grad():
            model_result = self.model(input_ids=input_ids,
                                      token_type_ids=token_type_ids,
                                      attention_mask=attention_mask,
                                      return_dict=True)

        return model_result.logits


wrapped_model = ModelWrapper(base_model)
traced_model = torch.jit.trace(wrapped_model.eval(),
                               (
                                   encoded_input['input_ids'],
                                   encoded_input['token_type_ids'],
                                   encoded_input['attention_mask']))
traced_model.eval()

inputs_shape = ct.Shape(shape=(1, ct.RangeDim(lower_bound=1, upper_bound=512)))

model_cm = ct.convert(
    traced_model,
    inputs=[
        ct.TensorType(name="input_ids", shape=inputs_shape, dtype=numpy.int32),
        ct.TensorType(name="token_type_ids", shape=inputs_shape, dtype=numpy.int32),
        ct.TensorType(name="attention_mask", shape=inputs_shape, dtype=numpy.int32)
    ],
    outputs=[ct.TensorType(name="logits")],
    convert_to="mlprogram",
    compute_precision=ct.precision.FLOAT32 # <-- This was added
)

x_cm = {k: v.to(torch.int32).numpy() for k, v in encoded_input.items()}
y_cm = model_cm.predict(x_cm)

print(y_cm)

Example 2

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
import coremltools as ct

sentences = ["This is a test."]
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2', torchscript=True).eval()
encoded_input = tokenizer(sentences, return_tensors='pt')
input_values = tuple(encoded_input.values())

x = (encoded_input['input_ids'], encoded_input['attention_mask'], encoded_input['token_type_ids'])

traced_model = torch.jit.trace(model, x)

model_cm = ct.convert(
    traced_model, source="pytorch",
    inputs=[
        ct.TensorType(name="input_ids", shape=input_values[0].shape, dtype=np.int32),
        ct.TensorType(name="attention_mask", shape=input_values[2].shape, dtype=np.int32),
        ct.TensorType(name="token_type_ids", shape=input_values[1].shape, dtype=np.int32)
    ],
    outputs=[
        ct.TensorType(name="out"),
        ct.TensorType(name="hidden_states"),
    ],
    compute_units=ct.ComputeUnit.CPU_ONLY,
    convert_to="mlprogram",
    compute_precision=ct.precision.FLOAT32  # <-- This was added
)

out_t, hidden_states_t = traced_model(*x)

x_cm = {k: v.to(torch.int32).numpy() for k, v in encoded_input.items()}
y_cm = model_cm.predict(x_cm)

print(y_cm)
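
As a quick sanity check (not part of the original report), one could compare the Core ML prediction against the traced PyTorch output; the "out" key matches the output name passed to ct.convert above, and the tolerance here is just an assumption:

# Hypothetical check: Core ML output vs. traced PyTorch output.
print(np.allclose(out_t.detach().numpy(), y_cm["out"], atol=1e-3))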
