
Can't finetune 20B model from slim weights with zero optimizer enabled #926

Open
coreystatendet opened this issue May 5, 2023 · 2 comments
Labels
bug Something isn't working

Comments

@coreystatendet

Describe the bug
When attempting to finetune (with no_load_optim=True) from slim weights, the DeepSpeed engine still tries to load optimizer-state files named zero_pp_rank_X_mp_rank_Y_optim_states.pt. These files don't exist in the slim weights.

Here's a traceback:

  File "/workdir/megatron/training.py", line 187, in pretrain     model, optimizer, lr_scheduler = setup_model_and_optimizer(
  File "/workdir/megatron/training.py", line 638, in setup_model_and_optimizer     neox_args.iteration = load_checkpoint(
  File "/workdir/megatron/checkpointing.py", line 247, in load_checkpoint     checkpoint_name, state_dict = model.load_checkpoint(
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 2783, in load_checkpoint     success = self._load_zero_checkpoint(
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 2962, in _load_zero_checkpoint     zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 3056, in _get_all_zero_checkpoints     return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 3028, in _get_all_zero_checkpoint_state_dicts     _state = self.checkpoint_engine.load(
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py", line 24, in load     partition = torch.load(path, map_location=map_location)
  File "/opt/conda/lib/python3.8/site-packages/torch/serialization.py", line 594, in load     with _open_file_like(f, 'rb') as opened_file:
  File "/opt/conda/lib/python3.8/site-packages/torch/serialization.py", line 230, in _open_file_like     return _open_file(name_or_buffer, mode)
  File "/opt/conda/lib/python3.8/site-packages/torch/serialization.py", line 211, in __init__     super(_open_file, self).__init__(open(name, mode))
FileNotFoundError: [Errno 2] No such file or directory: '/shared_fs/20B_checkpoints/global_step150000/zero_pp_rank_0_mp_rank_03_optim_states.pt'
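
For reference, a quick way to confirm what the checkpoint directory actually contains (a sketch only, assuming the directory layout from the traceback above):

import glob
import os

# Directory from the FileNotFoundError above (slim weights).
ckpt_dir = "/shared_fs/20B_checkpoints/global_step150000"

# Module weights shipped with the slim release.
model_states = glob.glob(os.path.join(ckpt_dir, "*model_states.pt"))
# ZeRO optimizer shards the engine tries to load when zero_optimization.stage > 0.
zero_shards = glob.glob(os.path.join(ckpt_dir, "zero_pp_rank_*_optim_states.pt"))

print(f"model-state files found:     {len(model_states)}")
print(f"ZeRO optimizer shards found: {len(zero_shards)}")  # 0 for slim weights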

To Reproduce
Run any training with zero_optimization.stage > 0, finetune=True, no_load_optim=True, and load set to a location containing the slim weights from here: https://the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/.
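
For concreteness, a minimal set of overrides that reproduces the failure (paths illustrative; remaining settings as in Config 1 below), whether set in a config file or passed as overwrite_values:

{
  "finetune": true,
  "no_load_optim": true,
  "load": "/shared_fs/20B_checkpoints",
  "zero_optimization": {
    "stage": 1
  }
}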

Expected behavior
Weights are loaded successfully and training continues, since no optimizer state should be needed with no_load_optim=True.

Proposed solution
It looks like using universal checkpoints simply skips the ZeRO optimizer loading, so loading those states is perhaps not essential -- maybe a check for no_load_optim == True could also be used to skip it?
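
A minimal sketch of that idea, as a fragment inside megatron/checkpointing.py's load_checkpoint (the flag names are DeepSpeed's engine.load_checkpoint arguments; whether they are enough to bypass _load_zero_checkpoint depends on the installed DeepSpeed version, so this is an assumption, not the merged fix):

# Hypothetical guard around the existing model.load_checkpoint call (a sketch, not the actual patch).
load_optim_and_scheduler = not (neox_args.no_load_optim or neox_args.finetune)

checkpoint_name, state_dict = model.load_checkpoint(
    neox_args.load,
    load_optimizer_states=load_optim_and_scheduler,      # ask DeepSpeed not to restore optimizer state
    load_lr_scheduler_states=load_optim_and_scheduler,   # nor LR scheduler state
)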

Screenshots
N/A

Environment (please complete the following information):

  • GPUs: 16xA100 80GB (across 2 nodes)
  • Configs:

Config 1:

{
  "pipe-parallel-size": 4,
  "model-parallel-size": 2,
  "attention_config": [[["flash"], 44]],
  "num-layers": 44,
  "hidden-size": 6144,
  "num-attention-heads": 64,
  "seq-length": 2048,
  "max-position-embeddings": 2048,
  "norm": "layernorm",
  "pos-emb": "rotary",
  "rotary_pct": 0.25,
  "no-weight-tying": true,
  "gpt_j_residual": true,
  "output_layer_parallelism": "column",
  "scaled-upper-triang-masked-softmax-fusion": true,
  "bias-gelu-fusion": true,
  "init_method": "small_init",
  "output_layer_init_method": "wang_init",
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.97e-4,
      "betas": [0.9, 0.95],
      "eps": 1.0e-8,
      }
      },
  "min_lr": 0.97e-5,
  "zero_optimization": {
  "stage": 1,
  "allgather_partitions": True,
  "allgather_bucket_size": 1260000000,
  "overlap_comm": True,
  "reduce_scatter": True,
  "reduce_bucket_size": 1260000000,
  "contiguous_gradients": True,
  },
  "train_micro_batch_size_per_gpu": 4,
  "gradient-accumulation-steps": null,
  "data-impl": "mmap",
  "split": "995,4,1",
  "flops_profiler": {
  "enabled": true,
  "profile_step": 5,
  "module_depth": -1,
  "top_modules": 1,
  "detailed": true,
  "output_file": null,
  },
  "checkpoint-activations": false,
  "checkpoint-num-layers": 1,
  "partition-activations": false,
  "synchronize-each-layer": true,
  "gradient_clipping": 1.0,
  "weight-decay":   0.01,
  "hidden-dropout": 0,
  "attention-dropout": 0,
  "fp16": {
    "fp16": true,
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 12,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "train-iters": 1000,
  "lr-decay-iters": 1000,
  "distributed-backend": "nccl",
  "lr-decay-style": "cosine",
  "warmup": 0.01,
  "checkpoint-factor": 50,
  "eval-interval": 100,
  "eval-iters": 10,
  "log-interval": 2,
  "steps_per_print": 2,
  "wall_clock_breakdown": false,
  "tokenizer_type": "HFTokenizer",
  "tensorboard-dir": "./tensorboard",
  "log-dir": null
}

Config 2 (anonymized):

{
  "vocab-file": "/shared_fs/20B_checkpoints/20B_tokenizer.json",
  "load": "/shared_fs/20B_checkpoints",
  "data-path": "/shared_fs/data",
  "override_lr_scheduler": true,
}

overwrite_values passed to NeoXArgs.from_ymls: {"finetune": True, "no_load_optim": True}

Additional context
N/A

coreystatendet added the bug label on May 5, 2023
@StellaAthena (Member)

Can you check if this PR fixes your problem? #927

@StellaAthena (Member)

The above PR is now on main, @coreystatendet.
