Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generation / FIX: Fix multi-device generation #30746

Merged
merged 11 commits into from May 13, 2024

Conversation

younesbelkada
Copy link
Contributor

@younesbelkada younesbelkada commented May 10, 2024

What does this PR do?

Fixes failing tests for multi-device (e.g. Multi-GPU, GPU + CPU etc) generation. The fix is simply to make sure pad_token_id and all other special tokens are initialized on the correct device (e.g. for models offloaded on CPU self.device return "meta" which breaks the generation after 馃槩 )

cc @gante @ArthurZucker

@younesbelkada younesbelkada marked this pull request as ready for review May 10, 2024 16:41
@younesbelkada
Copy link
Contributor Author

The fix is to initialize the special tokens on the correct devices all the time, I updated the description of the PR

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker
Copy link
Collaborator

cc @gante !

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a suggestion to enable this change on all modalities!

Comment on lines 1522 to 1526
device = None
if "input_ids" in model_kwargs and isinstance(model_kwargs["input_ids"], torch.Tensor):
device = model_kwargs["input_ids"].device

self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I get it right: the device comes from the main model input, and not from the model itself.

Assuming what I wrote above is correct, we should get the device variable after the _prepare_model_inputs call, which extracts the main model input from the different keywords we might see (for instance, Whisper does not use input_ids). In that case, I would move these lines to after L1532 (currently batch_size = inputs_tensor.shape[0]), and use device=inputs_tensor.device :D

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes totally sense! Done!

@younesbelkada younesbelkada requested a review from gante May 13, 2024 10:02
@younesbelkada younesbelkada changed the title Generation / FIX: Attempt to fix multi-device generation Generation / FIX: Fix multi-device generation May 13, 2024
Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perfect, thank you for iterating 馃憣

@younesbelkada
Copy link
Contributor Author

Thanks ! cc @ArthurZucker for the final review

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing. A small test is welcome (instead of the slow one!) to make sure we catch this earlier!

@@ -476,6 +476,7 @@ def _prepare_attention_mask_for_generation(
)
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
attention_mask_from_padding = inputs.ne(pad_token_id).long()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

weird that this is changed 馃槃

@younesbelkada
Copy link
Contributor Author

thanks ! I don't think we can add tests as they would require a GPU, this is implictly tested through our models + quantization slow tests, hence how I catched the bug

@younesbelkada younesbelkada merged commit f823fec into huggingface:main May 13, 2024
21 checks passed
@younesbelkada younesbelkada deleted the fix-multi-gpu-bnb branch May 13, 2024 12:35
@ArthurZucker
Copy link
Collaborator

ok if there is no way to repro with a minimal trick putting the weights on meta device voluntarily!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants