use_reentrant=False can't be set properly #30749

getao opened this issue May 10, 2024 · 6 comments

getao commented May 10, 2024
System Info

transformers==4.40.1
deepspeed==0.14.2
torch==2.2.1

Who can help?

@ArthurZucker Hello, I used the transformers Trainer with DeepSpeed to train a decoder-only model. As a common way to reduce memory, I enabled gradient checkpointing and set use_reentrant to False in my code:

    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}

When I printed the training_args, the setting showed up properly:

...
gradient_checkpointing=True,
gradient_checkpointing_kwargs={'use_reentrant': False},
...

The training_args is properly passed to Trainer with:

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator
    )

However, when training starts, I still always get the following warning:

/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
  warnings.warn(

I don't understand what causes this message, but it seems that use_reentrant is not being set in a way that takes effect.

Could anyone please help me take a look at the problem?
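In case it helps, here is a minimal sketch of what I understand to be the supported path (assuming the gradient_checkpointing_kwargs field of TrainingArguments and the gradient_checkpointing_enable API from recent transformers releases), setting the option directly rather than patching the parsed training_args afterwards; "out" is just a hypothetical output directory:

    from transformers import TrainingArguments

    # Sketch: pass gradient_checkpointing_kwargs directly to TrainingArguments
    # instead of assigning it onto the parsed object after the fact.
    training_args = TrainingArguments(
        output_dir="out",
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
    )

    # The same setting can also be applied to the model itself (this is what the
    # Trainer does internally when gradient_checkpointing=True); here `model` is
    # the AutoModelForCausalLM loaded in the reproduction below.
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": False}
    )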

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

def train_model(model, train_dataset, eval_dataset, training_args, data_collator=None):
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator
    )
    # training is where the warning appears
    trainer.train()

def main():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)

    train_dataset, eval_dataset = load_data(...)

    train_dataset = train_dataset.map(
        group_texts,
        batched=True,
        fn_kwargs={"block_size": data_args.max_seq_length},
        load_from_cache_file=True
    )
    eval_dataset = eval_dataset.map(
        group_texts,
        batched=True,
        fn_kwargs={"block_size": data_args.max_seq_length},
        load_from_cache_file=True
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch.bfloat16 if training_args.bf16 else torch.float16,
        resume_download=True,
        trust_remote_code=True,
        attn_implementation="flash_attention_2"
    )

    train_model(model, train_dataset, eval_dataset, training_args)

Some of the relevant DeepSpeed configs are as follows:

    "zero_optimization": {
        "stage": 3,
        "allgather_partitions": true,
        "allgather_bucket_size": 5e7,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e7,
        "contiguous_gradients" : true,
        "stage3_max_live_parameters" : 1e8,
        "stage3_max_reuse_distance" : 1e8,
        "stage3_prefetch_bucket_size" : 5e7,
        "stage3_param_persistence_threshold" : 1e7,
        "sub_group_size" : 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },

    "activation_checkpointing": {
        "partition_activations": true,
        "contiguous_memory_optimization": false,
        "cpu_checkpointing": false,
        "number_checkpoints": null,
        "synchronize_checkpoint_boundary": false,
        "profile": false
    },

Expected behavior

No warning message

cw235 commented May 10, 2024

Hi there! It seems like you are encountering a warning related to the use_reentrant parameter when training your model with deepspeed and transformers. The warning is advising you to explicitly pass in use_reentrant=True or use_reentrant=False to the torch checkpoint function.

In your code snippet, you have set use_reentrant to False in the gradient_checkpointing_kwargs, but the warning indicates that it needs to be explicitly passed when using torch's checkpointing mechanism.

To address this warning and ensure that use_reentrant is properly set, you can explicitly pass this argument to the Trainer constructor where you instantiate the trainer, like so:

checkpointing_args = {"use_reentrant": False}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs=checkpointing_args
)

By explicitly passing gradient_checkpointing=True and gradient_checkpointing_kwargs=checkpointing_args, you are ensuring that the use_reentrant argument is set correctly and should prevent the warning message from appearing during training.

If you have any further questions or if the issue persists, feel free to ask for more assistance!

getao (Author) commented May 11, 2024

> To address this warning and ensure that use_reentrant is properly set, you can explicitly pass this argument to the Trainer constructor where you instantiate the trainer [...]

TypeError: Trainer.init() got an unexpected keyword argument 'gradient_checkpointing_kwargs'

It doesn't work.

BTW, are you a model?

amyeroberts (Collaborator) commented

@cw235 There is a series of issues you've commented on, e.g. here, here and here, in which the issue has clearly been fed through a chat model such as ChatGPT to produce an output that does not answer or address the question. Please refrain from doing this; it is unhelpful to the person reporting the issue and isn't scalable behaviour: what would happen if everyone started doing this?

All of these will be marked as spam so as to hide them and make the issues navigable. If you keep doing this, the account @cw235 will be reported.

amyeroberts (Collaborator) commented

cc @pacman100 @muellerzr

muellerzr (Contributor) commented

I couldn't recreate this warning. Can you try either:

  1. Using the dev version of transformers pip install git+https://github.com/huggingface/transformers
  2. Upgrading your PyTorch version pip install torch -U

And let us know if you get this warning?

Thanks!
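If the warning still shows up after that, a quick way to see which call site is responsible (a debugging sketch) is to escalate this specific warning to an error at the top of the training script, so the traceback points at the checkpoint call that is missing use_reentrant, whether it comes from the Trainer, DeepSpeed, or the model's own code:

    import warnings

    # Debugging sketch: turn this specific UserWarning into an exception so the
    # resulting traceback shows exactly which torch.utils.checkpoint call is
    # being made without use_reentrant.
    warnings.filterwarnings(
        "error",
        message=r"torch\.utils\.checkpoint: please pass in use_reentrant",
        category=UserWarning,
    )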

getao (Author) commented May 13, 2024

> I couldn't recreate this warning. Can you try either: 1. Using the dev version of transformers 2. Upgrading your PyTorch version [...]

Thank you. Is it possible that this is related to specific models loaded with trust_remote_code=True?
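For example (a hypothetical sketch, not the actual remote code), a modeling file that calls torch.utils.checkpoint.checkpoint directly would still warn even though the Trainer passes use_reentrant=False to the checkpointing function it installs via gradient_checkpointing_enable:

    import torch
    import torch.utils.checkpoint

    # Hypothetical pattern sometimes found in custom (trust_remote_code) modeling
    # files: the block calls torch.utils.checkpoint.checkpoint directly, so the
    # use_reentrant value from gradient_checkpointing_kwargs never reaches it and
    # PyTorch falls back to its default, emitting the warning.
    class CustomBlockWithOwnCheckpointing(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(16, 16)
            self.gradient_checkpointing = True

        def forward(self, hidden_states):
            if self.gradient_checkpointing and self.training:
                # No use_reentrant here -> UserWarning from torch.utils.checkpoint
                return torch.utils.checkpoint.checkpoint(self.linear, hidden_states)
            return self.linear(hidden_states)

    block = CustomBlockWithOwnCheckpointing().train()
    x = torch.randn(2, 16, requires_grad=True)
    block(x)  # warns: "please pass in use_reentrant=True or use_reentrant=False"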
