
2.3 - If use_reentrant is not explicitly passed, an exception will now be raised #637

Closed
apachemycat opened this issue May 3, 2024 · 4 comments


@apachemycat

PyTorch 2.3 - If use_reentrant is not explicitly passed, an exception will now be raised:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/xtuner/tools/train.py", line 360, in <module>
    main()
  File "/usr/local/lib/python3.10/dist-packages/xtuner/tools/train.py", line 356, in main
    runner.train()
  File "/usr/local/lib/python3.10/dist-packages/mmengine/runner/runner.py", line 1777, in train
    model = self.train_loop.run()  # type: ignore
  File "/usr/local/lib/python3.10/dist-packages/mmengine/runner/loops.py", line 287, in run
    self.run_iter(data_batch)
  File "/usr/local/lib/python3.10/dist-packages/mmengine/runner/loops.py", line 311, in run_iter
    outputs = self.runner.model.train_step(
  File "/usr/local/lib/python3.10/dist-packages/mmengine/model/wrappers/distributed.py", line 121, in train_step
    losses = self._run_forward(data, mode='loss')
  File "/usr/local/lib/python3.10/dist-packages/mmengine/model/wrappers/distributed.py", line 161, in _run_forward
    results = self(**data, mode=mode)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1523, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/xtuner/model/sft.py", line 228, in forward
    return self.compute_loss(data, data_samples)
  File "/usr/local/lib/python3.10/dist-packages/xtuner/model/sft.py", line 272, in compute_loss
    return self._compute_sequence_parallel_loss(data)
  File "/usr/local/lib/python3.10/dist-packages/xtuner/model/sft.py", line 262, in _compute_sequence_parallel_loss
    outputs = self.llm(**data)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/peft/peft_model.py", line 1395, in forward
    return self.base_model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/peft/tuners/tuners_utils.py", line 179, in forward
    return self.model.forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 1205, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 989, in forward
    layer_outputs = self._gradient_checkpointing_func(
  File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 417, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 25, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 460, in checkpoint
    raise ValueError(
ValueError: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
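
For reference, the explicit call that the error asks for looks roughly like this (a minimal sketch with placeholder names; layer and hidden_states are illustrative, not from the actual code):

import torch
from torch.utils.checkpoint import checkpoint

# Illustrative module and input only; requires_grad so checkpointing
# actually has gradients to recompute.
layer = torch.nn.Linear(16, 16)
hidden_states = torch.randn(2, 16, requires_grad=True)

# Pick a variant explicitly; PyTorch recommends the non-reentrant one.
out = checkpoint(layer, hidden_states, use_reentrant=False)

# To keep the old default behavior instead:
# out = checkpoint(layer, hidden_states, use_reentrant=True)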

@LZHgrla
Collaborator

LZHgrla commented May 6, 2024

@apachemycat Hi!

I checked the v2.3.0 code of PyTorch and found that it should raise a warning, not an error.

So could you please provide your PyTorch version and the installation commands?

https://github.com/pytorch/pytorch/blob/v2.3.0/torch/utils/checkpoint.py#L463
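
One quick way to check the installed version (a standard snippet, assuming torch is importable):

import torch
print(torch.__version__)  # e.g. "2.3.0", or a nightly such as "2.4.0.dev20240503"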

@apachemycat
Author

Maybe I'm using a 2.4 snapshot version. But how do I set this parameter? I tried and failed.

@LZHgrla
Collaborator

LZHgrla commented May 7, 2024

@apachemycat

transformers automatically sets use_reentrant to True for its built-in models; see huggingface/transformers#28538.
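
For the built-in models, the flag can also be passed explicitly through gradient_checkpointing_enable (a sketch; the model name is only an example, and gradient_checkpointing_kwargs requires a recent transformers version):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# transformers forwards these kwargs to torch.utils.checkpoint.checkpoint,
# so use_reentrant reaches the checkpoint call explicitly.
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False})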

This means the built-in models already receive this parameter. However, custom models do not set it (and we cannot pass it in either); see for example https://huggingface.co/THUDM/chatglm3-6b-base/blob/f91a1de587fdc692073367198e65369669a0b49d/modeling_chatglm.py#L632

Therefore, the easiest way to solve this issue is to downgrade your PyTorch version (e.g. to v2.3.0, which only raises a warning).

@apachemycat
Author

thanks

@pppppM pppppM closed this as completed May 8, 2024