Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix stopping_criteria result check in coca_model #860

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

MengqingCao
Copy link

fix #847

The stopping criteria is updated in the latest transformers(V4.39.3 now). The return result is modified to a tensor (torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool) ) instead of a bool value, which causes the bug in #847

the related code in transformers
https://github.com/huggingface/transformers/blob/0bd58f1ce0573c0e3269de4215a17d318add49b9/src/transformers/generation/stopping_criteria.py#L76

@MengqingCao
Copy link
Author

@rwightman
Copy link
Collaborator

rwightman commented May 9, 2024

@MengqingCao hmm, yeah this needs fixing. One q though as I'm not intimately familiar with the gen code. Does the implementation here support multiple sentences in a batch? If yes, using any() will stop too early as the bool tensor is done state for each row (sentence) no? If it's only supporting one sentence at a time this is fine for now..
@gpucce ?

@MengqingCao
Copy link
Author

@MengqingCao hmm, yeah this needs fixing. One q though as I'm not intimately familiar with the gen code. Does the implementation here support multiple sentences in a batch? If yes, using any() will stop too early as the bool tensor is done state for each row (sentence) no? If it's only supporting one sentence at a time this is fine for now.. @gpucce ?

Your concerns are right in the cases when users using StopStringCriteria and EosTokenCriteria, which I ignored before. I only noticed the default StoppingCriteria method MaxLengthCriteria before, which returns a boolTensor filled with one single bool value is_done. Thus, I think use any() brings bigger operating efficiency than all().

The related code in Transformers:
image

To adapt to the situation of StopStringCriteria and EosTokenCriteria at the same time, I think we have two choices:

  1. change to use all() here
  2. checking if there is StopStringCriteria and EosTokenCriteria in stopping_criteria, if no, use any(), otherwise, use all(). This may run faster but bring more changes than 1

@MengqingCao
Copy link
Author

@rwightman @gpucce , I have implemented option 2 and updated the code, give me some suggestions plz, thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

RuntimeError: Boolean value of Tensor with more than one value is ambiguous
2 participants