[BUG] Inconsistent loss between overlap_comm=true and overlap_comm=false #1004

Open
0x6b64 opened this issue Jul 27, 2023 · 4 comments
Labels
bug Something isn't working

Comments

@0x6b64

0x6b64 commented Jul 27, 2023

Describe the bug
DeepSpeed provides a ZeRO configuration property, overlap_comm, which according to the documentation "[a]ttempts to overlap the reduction of the gradients with backward computation" (ref: https://www.deepspeed.ai/docs/config-json/). I'm noticing that the lm loss differs when overlap_comm=true compared to when overlap_comm=false.

To Reproduce
I'm running training with the enwik8_text_document corpus. Here is my configuration:

```json
{
  "pipe-parallel-size": 1,
  "model-parallel-size": 8,
  "num-layers": 80,
  "hidden-size": 8192,
  "num-attention-heads": 64,
  "seq-length": 4096,
  "max-position-embeddings": 4096,
  "norm": "layernorm",
  "pos-emb": "rotary",
  "rotary_pct": 0.25,
  "no-weight-tying": true,
  "gpt_j_residual": true,
  "output_layer_parallelism": "column",
  "attention-config": [
    [
      [
        "flash"
      ],
      80
    ]
  ],
  "scaled-upper-triang-masked-softmax-fusion": true,
  "bias-gelu-fusion": true,
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.0006,
      "betas": [
        0.9,
        0.95
      ],
      "eps": 1.0e-6
    }
  },
  "min_lr": 0.00006,
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 500000000,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 500000000,
    "contiguous_gradients": true,
    "cpu_offload": false
  },
  "global_num_gpus": 512,
  "train_batch_size": 512,
  "train_micro_batch_size_per_gpu": 8,
  "gradient_accumulation_steps": 1,
  "data-impl": "mmap",
  "checkpoint-activations": true,
  "checkpoint-num-layers": 1,
  "partition-activations": true,
  "synchronize-each-layer": true,
  "gradient_clipping": 1.0,
  "weight-decay": 0.2,
  "hidden-dropout": 0,
  "attention-dropout": 0,
  "fp16": {
    "fp16": false,
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 12,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "flops_profiler": {
    "enabled": false,
    "profile_step": 2,
    "module_depth": -1,
    "top_modules": 1,
    "detailed": false,
    "output_file": null
  },
  "train-iters": 15,
  "lr-decay-iters": 360000,
  "distributed-backend": "nccl",
  "lr-decay-style": "cosine",
  "warmup": 0.01,
  "checkpoint-factor": 1000,
  "eval-interval": 4000,
  "eval-iters": 109,
  "log-interval": 1,
  "steps_per_print": 1,
  "wall_clock_breakdown": true,
  "use_wandb": false,
  "data-path": "enwik8_text_document",
  "vocab-file": "data/gpt2-vocab.json",
  "merge-file": "data/gpt2-merges.txt",
  "launcher": "openmpi",
  "deepspeed_mpi": true
}
```

Here are the observed losses:

overlap_comm=false    | overlap_comm=true
lm_loss: 1.247658E+01 | 1.247658E+01
lm_loss: 1.247326E+01 | 1.247326E+01
lm_loss: 1.203598E+01 | 1.203597E+01
lm_loss: 1.033948E+01 | 1.033946E+01
lm_loss: 1.536938E+01 | 1.536933E+01
lm_loss: 1.704729E+01 | 1.704725E+01
lm_loss: 1.572084E+01 | 1.572070E+01
lm_loss: 1.475276E+01 | 1.475274E+01
lm_loss: 1.341277E+01 | 1.341282E+01
lm_loss: 1.274201E+01 | 1.274182E+01
lm_loss: 1.182359E+01 | 1.182348E+01
lm_loss: 1.106408E+01 | 1.106395E+01
lm_loss: 1.068731E+01 | 1.068723E+01
lm_loss: 1.115071E+01 | 1.114968E+01
lm_loss: 1.079878E+01 | 1.079566E+01

One may argue that the losses are "close". However, the expectation is that the computation should be exact. Given that only the overlap_comm setting has changed, I wonder whether overlap_comm implicitly introduces some sort of data-copy race condition, especially since the losses appear to diverge more over time.

Expected behavior
The expectation is that the losses should be identical. DeepSpeed makes no contract that with overlap_comm the computation is only loosely correct.

Proposed solution
I've spent some time looking at the nsys profiles for this training, but nothing immediately stands out. With ZeRO Stage 2, there is some data movement that happens across different buffers.

In stage_1_and_2.py:905 the stream is set to reduction_stream, which waits for the default stream, after which a sequence of reduce operations is scheduled.

After average_tensor completes, there is a call to copy_grads_in_partition, which copies the reduced values from the ipg_buffer into a newly allocated buffer that holds the gradients.

With overlap_comm enabled, the cudaMemcpyAsync op is not synchronized with the completion of the reduce operation, and hence the data that is copied over (assumed to be the reduced result) may or may not have been reduced. This happens because the collective wait() semantic only synchronizes the completion of the reduce op on the collective stream with the default stream. When overlap_comm is enabled, the reduction_stream is used instead, and the wait() operation will not synchronize with it. This can be confirmed from the implementation in ProcessGroupNCCL as well as the PyTorch documentation on the use of wait() with asynchronous collectives.
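To make the hypothesized hazard concrete, here is a minimal PyTorch sketch. It is not DeepSpeed's actual code; reduce_then_copy, ipg_buffer, grad_partition, and reduction_stream are illustrative names that only loosely mirror stage_1_and_2.py. The point is that work.wait() on an async collective orders a single stream after the NCCL stream, so a non-blocking copy issued on a stream with no explicit dependency on the reduction can read unreduced data.

```python
# Minimal sketch of the stream-ordering hazard hypothesized above.
# Assumes dist.init_process_group("nccl", ...) has already been called.
import torch
import torch.distributed as dist


def reduce_then_copy(ipg_buffer: torch.Tensor,
                     grad_partition: torch.Tensor,
                     reduction_stream: torch.cuda.Stream) -> None:
    compute_stream = torch.cuda.current_stream()

    # Schedule the gradient reduction on a side stream, as ZeRO stage 2
    # does when overlap_comm is enabled.
    reduction_stream.wait_stream(compute_stream)  # reduce sees finished grads
    with torch.cuda.stream(reduction_stream):
        work = dist.all_reduce(ipg_buffer, async_op=True)
        # wait() makes one stream wait for the NCCL stream; it does not
        # order *other* streams after the collective, nor block the host.
        work.wait()

    # Hazard (as hypothesized in this issue): an async copy launched on a
    # stream that was never ordered after the reduction may read ipg_buffer
    # before the reduced values have landed.
    grad_partition.copy_(ipg_buffer, non_blocking=True)

    # An explicit ordering would remove the race:
    #   compute_stream.wait_stream(reduction_stream)
    #   grad_partition.copy_(ipg_buffer, non_blocking=True)
```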

However, if what I'm describing here were the case, the whole overlap_comm implementation would be incorrect. I'll create a similar issue with DeepSpeed as well, but I wanted to bring this up here in case anyone else has noticed different loss dynamics when overlap_comm is toggled.

Screenshots
N/A

Environment (please complete the following information):

  • GPUs: 64 AWS p4d nodes (512 GPUs total).
  • Configs: Attached above.

Additional context

  • Using DeepSpeed 0.9.5
0x6b64 added the bug (Something isn't working) label Jul 27, 2023
@0x6b64
Author

0x6b64 commented Jul 27, 2023

@clumsy

@dashstander
Contributor

@0x6b64 have you looked into this any more? My first reaction is indeed that those are pretty small differences, and I wonder if they may just come from non-deterministic PyTorch ops. One thing worth checking would be to make a DeepSpeed model that (somehow) only uses deterministic ops and re-run these tests.
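For reference, a minimal sketch of what forcing deterministic kernels in PyTorch might look like is below (run before model construction). Whether the fused-softmax / flash-attention kernels enabled in the config above have deterministic implementations is a separate question.

```python
# Sketch: request deterministic kernels in PyTorch before building the model.
import os
import torch

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # needed for deterministic cuBLAS GEMMs
torch.manual_seed(1234)
torch.use_deterministic_algorithms(True)  # raise an error on nondeterministic ops
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
```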

@0x6b64
Author

0x6b64 commented Sep 15, 2023

@dashstander - thanks for the response. Unfortunately, I haven't had time to dive into this yet. If I have any updates, I'll report them here.

@dashstander
Contributor

I brought this up with @Quentin-Anthony and he was skeptical of my "non-deterministic ops" theory, and he'd know much better than I, so this is definitely a bit of a mystery!
