
Recent version of Transformers seems to mess with forward/__call__ and breaks patching the loss function #30753

grahamannett opened this issue May 10, 2024 · 6 comments


grahamannett commented May 10, 2024

System Info

$ python --version
Python 3.11.6

$ pip show transformers
Name: transformers
Version: 4.40.1

I updated to a recent version of transformers for various models/bug fixes, and I believe something in transformers is now breaking the ability to patch/wrap a forward that takes in labels. I am completely at a loss as to where it could be happening, but it seems to affect many different models that use transformers.modeling_utils.PreTrainedModel, transformers.modeling_utils.ModuleUtilsMixin, or transformers.integrations.peft.PeftAdapterMixin. A similar but simpler example I tried without transformers does not have this issue.

The code below is a somewhat odd (in the sense that I am not sure this is the best way to do something like this) and simplified reproduction, but the general idea is that when I pass in labels, I do not want to forward the labels to the original model's forward (for instance, if you want to pass kwargs for weights/reduction/etc. to CrossEntropyLoss). If I do pass labels to forward and just overwrite the outputs.loss value, it works, but there are various reasons you may not want to do this (it computes the loss twice, the labels may be intended for a different loss function that won't work with the original model's loss, etc.).
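For reference, the "overwrite outputs.loss" workaround mentioned above looks roughly like this (sketch only, using the same model/input_ids/labels names as in the reproduction below; custom_loss_fn stands in for whatever loss you actually want):

# Sketch of the workaround: let the model compute its own loss (and discard it),
# then overwrite it with the custom one afterwards.
output = model(input_ids=input_ids, labels=labels)   # built-in loss computed here
output.loss = custom_loss_fn(output.logits, labels)  # replaced by the custom loss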

I am including a minimal reproduction below and didn't see any other recent issues reporting something similar. I am also really hoping this isn't user error, but I believe something like this worked fine with the transformers version I had a few months ago.

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Running this, you would expect the loss to be a single tensor value:

import torch
import torch.nn as nn
from transformers import LlamaForCausalLM


class WrappedForward(nn.Module):
    def __init__(self, model_forward, vocab_size):
        super().__init__()
        self.model_forward = model_forward
        self.vocab_size = vocab_size

    def loss_func(self, logits, labels):
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        shift_logits = shift_logits.view(-1, self.vocab_size)
        shift_labels = shift_labels.view(-1)
        shift_labels = shift_labels.to(shift_logits.device)

        loss_fct = nn.CrossEntropyLoss(reduction="mean")
        loss = loss_fct(shift_logits, shift_labels)
        return loss

    def forward(self, **kwargs):
        if (labels := kwargs.get("labels", None)) is not None:
            kwargs["labels"] = None

        output = self.model_forward(**kwargs)

        if labels is not None:
            output.loss = self.loss_func(output.logits, labels)

        return output


class ModelExample(LlamaForCausalLM):
    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        # save original forward for proving the concept
        self.original_forward = self.forward
        # wrapped forward is just a module/layer that will be called to remove labels from kwargs
        self.wrapped_forward = WrappedForward(self.original_forward, config.vocab_size)

        self.forward = self.wrapped_forward.forward


if __name__ == "__main__":
    model = ModelExample.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto")
    input_ids = torch.tensor([[1, 2, 3, 4, 5]], device=model.device)
    labels = torch.tensor([[2, 3, 4, 5, 6]], device=model.device)
    model.train()

    output = model.original_forward(input_ids=input_ids, labels=labels)
    assert isinstance(output.loss, torch.Tensor), "This will work fine"

    # Perform forward pass
    output = model(input_ids=input_ids, labels=labels)
    assert isinstance(output.loss, torch.Tensor), "This will fail with Transformers >=4.40.1"

Expected behavior

outputs.loss from the wrapped model should be a single tensor value, just like outputs.loss from the original forward.

@ArthurZucker
Collaborator

Hey! Thanks for opening the issue.
I can't seem to reproduce this; it produces a tensor for me on 4.41.0, so either it was fixed, or it's related to `accelerate`, which is the only thing that could affect all models at once like this.
Now, I would probably write the loss and the forward directly in ModelExample rather than creating a wrapper class, but that's personal taste!
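Something along these lines, for example (rough sketch only, not tested; it inlines the same shift-and-flatten cross-entropy your loss_func uses):

import torch.nn as nn
from transformers import LlamaForCausalLM


class ModelExample(LlamaForCausalLM):
    def forward(self, labels=None, **kwargs):
        # never hand labels to the base forward, so it skips its built-in loss
        output = super().forward(**kwargs)
        if labels is not None:
            # same shifted cross-entropy as loss_func in the reproduction above
            shift_logits = output.logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
            output.loss = nn.CrossEntropyLoss(reduction="mean")(shift_logits, shift_labels)
        return output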

@grahamannett
Author

Hi @ArthurZucker, I don't see 4.41.0 on PyPI or in this repo (releases/tags). Is there somewhere else to look, or do you just mean install from source?

I also just tried again on 4.40.1 and 4.40.2 with fresh envs/installs (from pip install "transformers[torch]>=4.40.2") and still see the error. Here is the requirements.txt, but I can provide a pyproject/Docker image if someone else can verify that this isn't just me:

accelerate==0.30.1
certifi==2024.2.2
charset-normalizer==3.3.2
filelock==3.14.0
fsspec==2024.3.1
huggingface-hub==0.23.0
idna==3.7
Jinja2==3.1.4
MarkupSafe==2.1.5
mpmath==1.3.0
networkx==3.3
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
packaging==24.0
psutil==5.9.8
PyYAML==6.0.1
regex==2024.5.10
requests==2.31.0
safetensors==0.4.3
sympy==1.12
tokenizers==0.19.1
torch==2.3.0
tqdm==4.66.4
transformers==4.40.2
triton==2.3.0
typing_extensions==4.11.0
urllib3==2.2.1

@grahamannett
Author

Just to make sure, I also installed from source (4.41.0.dev0) and still get the error.

@ArthurZucker
Collaborator

I am unable to reproduce. Can you share a Google Colab with this? 🤗

@grahamannett
Author

It seems to be specific to using device_map="auto" with the model spread across 2+ devices. Are you using a single large GPU? I'm guessing that probably means it is related to accelerate, which you mentioned as well, so there's nothing really to do about it in transformers?

@ArthurZucker
Collaborator

Yep, I was running on a single device.
@SunMarc and @muellerzr, would forcing two devices with a balanced device map help repro?
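For example, something like this should force the weights onto two GPUs so the model goes through accelerate's dispatch (rough sketch, assuming two visible devices; "balanced" and max_memory are standard from_pretrained arguments):

# Sketch: load the repro's ModelExample split across two GPUs instead of one.
model = ModelExample.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    device_map="balanced",
    max_memory={0: "10GiB", 1: "10GiB"},
)
print(model.hf_device_map)  # should list modules on more than one device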
