
Resuming training fails #105

Open
hidoba opened this issue Mar 31, 2024 · 3 comments
hidoba commented Mar 31, 2024

I removed the --overwrite_output_dir flag so that training could resume, but I'm getting the following error:

04/01/2024 00:30:01 - INFO - __main__ - max_steps is given, it will override any value given in num_train_epochs
04/01/2024 00:30:04 - INFO - __main__ - ***** Running training *****
04/01/2024 00:30:04 - INFO - __main__ -   Num examples = 4800000
04/01/2024 00:30:04 - INFO - __main__ -   Instantaneous batch size per device = 8
04/01/2024 00:30:04 - INFO - __main__ -   Gradient accumulation steps = 1
04/01/2024 00:30:04 - INFO - __main__ -   Total train batch size (w. parallel & distributed) = 8
04/01/2024 00:30:04 - INFO - __main__ -   Total optimization steps = 600000
Train steps ... :   0%|                                                                               | 0/600000 [00:00<?, ?it/s]04/01/2024 00:30:04 - INFO - accelerate.accelerator - Loading states from ./checkpoint-5000-epoch-0
Traceback (most recent call last):
  File "/home/vlad/distil-whisper/training/run_distillation.py", line 1682, in <module>
    main()
  File "/home/vlad/distil-whisper/training/run_distillation.py", line 1484, in main
    accelerator.load_state(checkpoint)
  File "/home/vlad/distil-whisper/.venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 2966, in load_state
    load_accelerator_state(
  File "/home/vlad/distil-whisper/.venv/lib/python3.10/site-packages/accelerate/checkpointing.py", line 205, in load_accelerator_state
    models[i].load_state_dict(state_dict, **load_model_func_kwargs)
  File "/home/vlad/distil-whisper/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2153, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for WhisperForConditionalGeneration:
        Missing key(s) in state_dict: "proj_out.weight". 

At the same time, the evaluation script works fine with the same checkpoint.

I'm using Ubuntu 22.04 with an RTX 3090 Ti.

hidoba commented Mar 31, 2024

I've also observed this in the log:

04/01/2024 00:35:47 - WARNING - accelerate.utils.other - Removed shared tensor {'proj_out.weight'} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading
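
This warning points at the root cause: safetensors stores only one name per underlying storage, so when proj_out.weight is tied to the decoder embedding it is dropped at save time, and a strict load_state_dict on reload then reports it as missing. A minimal pure-Python sketch of that failure mode (plain lists standing in for tensors; names are illustrative, not the actual library internals):

```python
def save_dedup(state):
    """Mimic what safetensors does while saving: keep only one name per
    underlying storage, dropping any key that aliases an already-saved one."""
    seen_ids = set()
    saved = {}
    for name, tensor in state.items():
        if id(tensor) in seen_ids:
            continue  # shared tensor: removed while saving
        seen_ids.add(id(tensor))
        saved[name] = tensor
    return saved

def strict_load(model_keys, saved):
    """Mimic load_state_dict(strict=True): every model key must be present."""
    missing = [k for k in model_keys if k not in saved]
    if missing:
        raise RuntimeError(f"Missing key(s) in state_dict: {missing}")

embed = [0.1, 0.2]                                  # stands in for a weight tensor
state = {"model.decoder.embed_tokens.weight": embed,
         "proj_out.weight": embed}                  # tied to the embedding
saved = save_dedup(state)                           # "proj_out.weight" is dropped
try:
    strict_load(state.keys(), saved)
except RuntimeError as e:
    print(e)  # Missing key(s) in state_dict: ['proj_out.weight']
```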

@Gusreis7

Any updates on this? I'm facing the same problem.

George0828Zhang commented Apr 30, 2024

Here's a temporary fix, based on https://huggingface.co/docs/safetensors/torch_shared_tensors

Modify load_accelerator_state(): https://github.com/huggingface/accelerate/blob/main/src/accelerate/checkpointing.py#L153

-from safetensors.torch import load_file
+from safetensors.torch import load_model
 ...
     if input_model_file.exists():
-        state_dict = load_file(input_model_file, device=str(map_location))
+        load_model(models[i], input_model_file, device=str(map_location), **load_model_func_kwargs)
     else:
         # Load with torch
         input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
         state_dict = torch.load(input_model_file, map_location=map_location)
-    models[i].load_state_dict(state_dict, **load_model_func_kwargs)
+        models[i].load_state_dict(state_dict, **load_model_func_kwargs)
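
Why this works: load_model writes the saved values into the model's existing storages in place, so a tensor whose alias was deduplicated at save time is restored through its surviving partner, rather than being looked up by name in a strict load. A pure-Python sketch of that in-place behavior (lists standing in for shared tensors; names illustrative):

```python
def load_into(model_state, saved):
    """Sketch of safetensors.torch.load_model's effect: copy saved values
    into the model's existing storages, so any key that aliases a saved
    key is updated through the shared storage."""
    for name, value in saved.items():
        model_state[name][:] = value  # in-place write reaches all aliases

embed = [0.0, 0.0]
model = {"model.decoder.embed_tokens.weight": embed,
         "proj_out.weight": embed}                         # tied weights
saved = {"model.decoder.embed_tokens.weight": [1.0, 2.0]}  # alias was dropped at save time

load_into(model, saved)
print(model["proj_out.weight"])  # → [1.0, 2.0]: the tie restores proj_out as well
```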
