
The plot from muP coord_check does not look horizontal, which may indicate there is a bug in the muP implementation? #956

Open
BaoYu0721 opened this issue May 28, 2023 · 11 comments

@BaoYu0721

BaoYu0721 commented May 28, 2023

Bug Description & To Reproduce
The source code is from the current main branch, and I followed the instructions in README-MUP.md up to this step:
[screenshot of the relevant step from README-MUP.md]
I encountered an error like this:

╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /mnt/cache/baoyu/gpt-neox/train.py:27 in <module>                            │
│                                                                              │
│   24 │   neox_args.configure_distributed_args()                              │
│   25 │   neox_args.build_tokenizer()  # tokenizer needs to be build in train │
│   26 │   neox_args.initialize_tensorboard_writer()  # is initialized if tens │
│ ❱ 27 │   pretrain(neox_args=neox_args)                                       │
│   28                                                                         │
│                                                                              │
│ /mnt/cache/baoyu/gpt-neox/megatron/training.py:211 in pretrain               │
│                                                                              │
│   208 │   timers("train/valid/test data iterators").stop()                   │
│   209 │                                                                      │
│   210 │   if neox_args.use_mup and neox_args.coord_check:                    │
│ ❱ 211 │   │   mup_coord_check(neox_args, timers, lr_scheduler, train_data_it │
│   212 │                                                                      │
│   213 │   # Print setup timing.                                              │
│   214 │   print_rank_0("done with setups ...")                               │
│                                                                              │
│ /mnt/cache/baoyu/gpt-neox/megatron/training.py:154 in mup_coord_check        │
│                                                                              │
│   151 │   │   models[hidden_size] = lazy_model(hidden_size)                  │
│   152 │                                                                      │
│   153 │   neox_args.use_mup = True                                           │
│ ❱ 154 │   df_up = get_coord_data(                                            │
│   155 │   │   neox_args, timers, lr_scheduler, models, train_data_iterator,  │
│   156 │   )                                                                  │
│   157 │   neox_args.use_mup = False                                          │
│                                                                              │
│ /mnt/cache/baoyu/gpt-neox/megatron/mup_substitute.py:207 in get_coord_data   │
│                                                                              │
│   204 │   elif optimizer is None:                                            │
│   205 │   │   raise ValueError("optimizer should be sgd|adam|adamw or a cust │
│   206 │                                                                      │
│ ❱ 207 │   data = _get_coord_data(                                            │
│   208 │   │   neox_args, timers, lr_scheduler, models, dataloader, optcls, * │
│   209 │   )                                                                  │
│   210 │   data["optimizer"] = optimizer                                      │
│                                                                              │
│ /mnt/cache/baoyu/gpt-neox/megatron/mup_substitute.py:69 in _get_coord_data   │
│                                                                              │
│    66 │   │   │   │   │   )                                                  │
│    67 │   │   │   │                                                          │
│    68 │   │   │   │   # train for a step                                     │
│ ❱  69 │   │   │   │   loss_dict, skipped_iter = train_step(                  │
│    70 │   │   │   │   │   neox_args=neox_args,                               │
│    71 │   │   │   │   │   timers=timers,                                     │
│    72 │   │   │   │   │   data_iterator=dataloader,                          │
│                                                                              │
│ /mnt/cache/baoyu/gpt-neox/megatron/training.py:695 in train_step             │
│                                                                              │
│   692 │                                                                      │
│   693 │   # Pipeline parallelism schedules forward/backward/step             │
│   694 │   if neox_args.is_pipe_parallel:                                     │
│ ❱ 695 │   │   reduced_loss = train_step_pipe(                                │
│   696 │   │   │   neox_args=neox_args, timers=timers, model=model, data_iter │
│   697 │   │   )                                                              │
│   698 │   else:                                                              │
│                                                                              │
│ /mnt/cache/baoyu/gpt-neox/megatron/training.py:745 in train_step_pipe        │
│                                                                              │
│   742 │   """Single training step with DeepSpeed's pipeline parallel engine. │
│   743 │                                                                      │
│   744 │   assert neox_args.deepspeed                                         │
│ ❱ 745 │   loss = model.train_batch(data_iter=data_iterator)                  │
│   746 │   loss_dict = {"lm_loss": loss}                                      │
│   747 │   # Don't break Megatron's timers because we changed code paths.     │
│   748 │   for t in [                                                         │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/runtime/ │
│ pipe/engine.py:336 in train_batch                                            │
│                                                                              │
│    333 │   │   sched = schedule.TrainSchedule(micro_batches=self.micro_batch │
│    334 │   │   │   │   │   │   │   │   │      stages=self.num_stages,        │
│    335 │   │   │   │   │   │   │   │   │      stage_id=self.stage_id)        │
│ ❱  336 │   │   self._exec_schedule(sched)                                    │
│    337 │   │   self.agg_train_loss = self._aggregate_total_loss()            │
│    338 │   │                                                                 │
│    339 │   │   self.timers('train_batch').stop()                             │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/runtime/ │
│ pipe/engine.py:1307 in _exec_schedule                                        │
│                                                                              │
│   1304 │   │   │   │                                                         │
│   1305 │   │   │   │   # Equivalent to: self._exec_forward_pass(buffer_id=0) │
│   1306 │   │   │   │   self._exec_instr = MethodType(self._INSTRUCTION_MAP[t │
│ ❱ 1307 │   │   │   │   self._exec_instr(**cmd.kwargs)                        │
│   1308                                                                       │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/runtime/ │
│ pipe/engine.py:627 in _exec_forward_pass                                     │
│                                                                              │
│    624 │   │   # tensor changes across batches                               │
│    625 │   │   self._zero_grads(inputs)                                      │
│    626 │   │                                                                 │
│ ❱  627 │   │   outputs = super().forward(inputs)                             │
│    628 │   │                                                                 │
│    629 │   │   # Reset activation checkpointing buffers.                     │
│    630 │   │   # Need to call this between evaluation iterations             │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/utils/nv │
│ tx.py:15 in wrapped_fn                                                       │
│                                                                              │
│   12 │                                                                       │
│   13 │   def wrapped_fn(*args, **kwargs):                                    │
│   14 │   │   get_accelerator().range_push(func.__qualname__)                 │
│ ❱ 15 │   │   ret_val = func(*args, **kwargs)                                 │
│   16 │   │   get_accelerator().range_pop()                                   │
│   17 │   │   return ret_val                                                  │
│   18                                                                         │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/runtime/ │
│ engine.py:1731 in forward                                                    │
│                                                                              │
│   1728 │   │   if self.fp16_auto_cast():                                     │
│   1729 │   │   │   inputs = self._cast_inputs_half(inputs)                   │
│   1730 │   │                                                                 │
│ ❱ 1731 │   │   loss = self.module(*inputs, **kwargs)                         │
│   1732 │   │                                                                 │
│   1733 │   │   if self.zero_optimization_partition_weights():                │
│   1734 │   │   │   # Disable automated discovery of external parameters      │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/torch/nn/modules/m │
│ odule.py:1212 in _call_impl                                                  │
│                                                                              │
│   1209 │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks)   │
│   1210 │   │   │   input = bw_hook.setup_input_hook(input)                   │
│   1211 │   │                                                                 │
│ ❱ 1212 │   │   result = forward_call(*input, **kwargs)                       │
│   1213 │   │   if _global_forward_hooks or self._forward_hooks:              │
│   1214 │   │   │   for hook in (*_global_forward_hooks.values(), *self._forw │
│   1215 │   │   │   │   hook_result = hook(self, input, result)               │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/runtime/ │
│ pipe/module.py:350 in forward                                                │
│                                                                              │
│   347 │   │   │   │   if self._is_checkpointable(funcs):                     │
│   348 │   │   │   │   │   x = self.activation_checkpoint_func(exec_range_fun │
│   349 │   │   │   │   else:                                                  │
│ ❱ 350 │   │   │   │   │   x = exec_range_func(start_idx, end_idx)(*x)        │
│   351 │   │   return x                                                       │
│   352 │                                                                      │
│   353 │   def _partition_layers(self, method='uniform'):                     │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/runtime/ │
│ pipe/module.py:327 in exec_func                                              │
│                                                                              │
│   324 │   │   │   │   │   │   else:                                          │
│   325 │   │   │   │   │   │   │   ds_utils.set_random_seed(new_seed)         │
│   326 │   │   │   │   │                                                      │
│ ❱ 327 │   │   │   │   │   inputs = layer(inputs)                             │
│   328 │   │   │   │   return inputs                                          │
│   329 │   │   │                                                              │
│   330 │   │   │   return exec_func                                           │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/torch/nn/modules/m │
│ odule.py:1215 in _call_impl                                                  │
│                                                                              │
│   1212 │   │   result = forward_call(*input, **kwargs)                       │
│   1213 │   │   if _global_forward_hooks or self._forward_hooks:              │
│   1214 │   │   │   for hook in (*_global_forward_hooks.values(), *self._forw │
│ ❱ 1215 │   │   │   │   hook_result = hook(self, input, result)               │
│   1216 │   │   │   │   if hook_result is not None:                           │
│   1217 │   │   │   │   │   result = hook_result                              │
│   1218                                                                       │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/mup/coord_check.py │
│ :161 in f                                                                    │
│                                                                              │
│   158 │   │   │   │   for i, out in enumerate(output):                       │
│   159 │   │   │   │   │   _ret = copy(ret)                                   │
│   160 │   │   │   │   │   _ret['module'] += f':out[{i}]'                     │
│ ❱ 161 │   │   │   │   │   get_stat(_ret, out, output_fdict)                  │
│   162 │   │   │   elif isinstance(output, dict):                             │
│   163 │   │   │   │   for name, out in output.items():                       │
│   164 │   │   │   │   │   _ret = copy(ret)                                   │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/mup/coord_check.py │
│ :145 in get_stat                                                             │
│                                                                              │
│   142 │   │   │   elif isinstance(x, torch.Tensor):                          │
│   143 │   │   │   │   _d = copy(d)                                           │
│   144 │   │   │   │   for fname, f in fdict.items():                         │
│ ❱ 145 │   │   │   │   │   _d[fname] = f(x).item()                            │
│   146 │   │   │   │   records.append(_d)                                     │
│   147 │   │   │   else:                                                      │
│   148 │   │   │   │   raise NotImplemented(f'Unexpected output type: {type(x │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/mup/coord_check.py │
│ :44 in <lambda>                                                              │
│                                                                              │
│    41                                                                        │
│    42 #: dict of provided functions for use in coord check                   │
│    43 FDICT = {                                                              │
│ ❱  44 │   'l1': lambda x: torch.abs(x).mean(),                               │
│    45 │   'l2': lambda x: (x**2).mean()**0.5,                                │
│    46 │   'mean': lambda x: x.mean(),                                        │
│    47 │   'std': lambda x: x.std(),                                          │
╰──────────────────────────────────────────────────────────────────────────────╯
RuntimeError: mean(): could not infer output dtype. Input dtype must be either a
floating point or complex dtype. Got: Bool

This is caused by a Bool tensor (maybe the attention mask) being passed into mup's get_stat, which the mup library cannot handle. In addition, we also encounter an error caused by None being passed to get_stat.

To work around this temporarily, I modified the source code of coord_check.py in mup like this:
[screenshot of the modified get_stat in mup/coord_check.py]
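Since the screenshot is not shown here, below is a minimal sketch of the kind of guard I mean, applied to the if/elif chain inside get_stat (the exact change in my local copy may differ): only record floating-point tensors and skip None outputs.

```python
elif isinstance(x, torch.Tensor):
    # Only floating-point (or complex) tensors can go through .mean()/.std();
    # skip e.g. the Bool attention mask instead of crashing.
    if x.is_floating_point() or x.is_complex():
        _d = copy(d)
        for fname, f in fdict.items():
            _d[fname] = f(x).item()
        records.append(_d)
elif x is None:
    # Skip None outputs instead of falling through to the error branch.
    pass
```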

This time, coord_check ran successfully. It outputs many JPGs, one per GPU; the JPGs from different GPUs look very similar, so I only show one JPG for each parameterization.

Standard Parameterization:
[plot: coord_check_sp]

muP Parameterization:
[plot: coord_check_up]

The result looks weird: SP is more horizontal than muP, which is not what is expected.

Expected behavior
https://github.com/microsoft/mup#coord-check
The expected behavior should look like the plots in the link above, in which muP is very horizontal while SP blows up.

Proposed solution
I have checked the code related to muP but don't have a proposal yet; I will keep looking into it. Maybe the contributors in issue #679 can give some comments? @nsarka @Quentin-Anthony @StellaAthena
Thanks a lot!

Environment:

  • GPUs: 8 V100 GPUs on one node
  • Configs:
    python deepy.py train.py configs/2-7B.yml configs/local_setup.yml

    I added the muP-related configs to configs/local_setup.yml and kept them completely identical to the instructions in README-MUP.md (a rough sketch of these settings is shown after this list).
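For reference, a rough sketch of the muP-related settings, not a verbatim excerpt from README-MUP.md: the key names are inferred from the neox_args.use_mup and neox_args.coord_check attributes visible in the traceback, and README-MUP.md lists the full set of required keys.

```yaml
# Sketch only -- see README-MUP.md for the authoritative settings.
"use-mup": true,
"coord-check": true,
```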
@BaoYu0721 BaoYu0721 added the bug Something isn't working label May 28, 2023
@StellaAthena
Member

Thanks for raising this issue. It looks like you’re correct and we broke the implementation at some point.

One thing we really need to start doing (but haven’t been able to do due to manpower limitations) is build out a robust testing suite that verifies new major changes don’t break old features :S

@Quentin-Anthony Quentin-Anthony self-assigned this May 28, 2023
@BaoYu0721
Author

Thanks for your reply! I checked out some other commits, such as the v2.0 release tag and the earlier commit where deepspeed_main was merged into main (2b84f9a), and found the plots are similar to the description above. Maybe the bug was introduced even earlier?

@StellaAthena StellaAthena mentioned this issue Jun 2, 2023
@ofivite

ofivite commented Aug 2, 2023

I was looking into the muP implementation in gpt-neox to contrast it with the Megatron-LM setup and accidentally found this issue :)

I am wondering: could the LR schedule be the cause of the problem? By design it overrides the LR values per group (and hence overwrites the muP changes), which is presumably why muP scaling was introduced in AnnealingLR() as a rescaling by group["width_mult"] in step() here. But I couldn't find this key being added either inside the mup optimisers or in the gpt-neox codebase, so I am not sure the width_mult rescaling is applied at all.

Also, the width_mult rescaling should only be applied for Adam-like optimisers and matrix-like params (as here), while for SGD the rescaling uses different multipliers, which should be taken into account (see the sketch at the end of this comment).

However, I also couldn't confirm whether the AnnealingLR schedule is actually applied during training, so it might well be that my comment isn't really relevant to the observed behaviour.
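To make the idea concrete, here is a rough sketch (hypothetical, not the actual gpt-neox code) of what the Adam-like rescaling inside AnnealingLR.step() would need to do for the muP correction to survive the per-group LR override:

```python
def _apply_lr(optimizer, new_lr):
    for group in optimizer.param_groups:
        # MuAdam-style groups for matrix-like params would need to carry a
        # "width_mult" entry; overwriting group["lr"] without dividing by it
        # silently drops the muP scaling.
        group["lr"] = new_lr / group.get("width_mult", 1.0)
```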

@marcobellagente93

I think you are right @ofivite; I don't think the implementation was ever correct, since the learning rate wasn't correctly set up from the beginning. After a long and thorough debugging session I managed to pass at least the two basic sanity checks for muP:

  1. At the same width, muP doesn't do anything (since all shapes are the same, it should coincide with SP).
  2. With everything else fixed, you only get better by going wider.

Issues I have found are:

  • The LR is not scaled; I added the width_mult key to the param group dicts (as stated above).
  • MuAdam is used instead of MuAdamW (maybe irrelevant, need to test more).
  • I have fixed all initializations to be normal with a fixed std. I'm not sure the method is correctly implemented, or should be used at all, for generic initialization functions, especially those that already scale the std by fan_in, since as far as I can see that same std is then scaled again (see the sketch below).
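To illustrate the last point, here is a rough sketch (hypothetical names, not the actual gpt-neox init code) of how a fan_in-scaled init combined with a muP-style rescaling ends up scaling the std twice, versus the fixed-std init I switched to:

```python
import math
import torch

def fan_in_scaled_init(weight, fan_in):
    # SP-style init that already bakes the width into the std.
    torch.nn.init.normal_(weight, mean=0.0, std=1.0 / math.sqrt(fan_in))

def mup_rescale_(weight, width_mult):
    # muP-style rescaling of a matrix-like weight by 1/sqrt(width_mult);
    # applied on top of fan_in_scaled_init this gives std ~ 1/fan_in
    # instead of the intended ~ 1/sqrt(fan_in).
    with torch.no_grad():
        weight.div_(math.sqrt(width_mult))

def fixed_std_init(weight, std=0.02):
    # The workaround: a fixed std, so only the muP rescaling depends on width.
    torch.nn.init.normal_(weight, mean=0.0, std=std)
```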

@marcobellagente93

marcobellagente93 commented Aug 29, 2023

Found another bug: neox_args.use_mup is set to False before initializing the models, which also sets their use_mup attribute to False, and therefore the multipliers are always ignored.

@marcobellagente93

And finally, there seems to be a bug in the re-initialization of the output layer; after skipping that completely (it should anyway be in the flavour of Table 8), I'm getting these very nice and smooth horizontal lines:

[plot: coord_check_up]

@ofivite

ofivite commented Aug 31, 2023

@marcobellagente93 Oh yes, those are indeed nicely flat curves now, great! :)

@marcobellagente93

I'll make a PR as soon as I can

@nsarka
Contributor

nsarka commented Aug 31, 2023 via email

@marcobellagente93

What I mean is that the main training loop with muP enabled does the following:

  1. set neox_args.use_mup to False
  2. initialize the model
  3. set neox_args.use_mup back to True

but at step 1 all parameters get initialized with self.use_mup = neox_args.use_mup (which is False), and this causes everything else to be wrong (multipliers not used, 1/d attention not used, ...). A simplified sketch of the pattern is shown below.
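An illustrative sketch of that pattern (simplified, not the actual gpt-neox code):

```python
neox_args.use_mup = False         # step 1
model = get_model(neox_args)      # step 2: layers capture self.use_mup = False
neox_args.use_mup = True          # step 3: only the args object is updated

# The layers keep the stale False, so the muP multipliers and the 1/d
# attention scaling stay disabled unless they are explicitly re-enabled
# on the modules after construction.
```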

@nsarka
Contributor

nsarka commented Aug 31, 2023

This behavior is expected; the weights are reinitialized using

mup_weights_reinit(neox_args, model)

Or do you mean this function does not get called?
