
Fine-tuning gpt-neox on 8 A100s #892

Open
rajhans opened this issue Apr 20, 2023 · 6 comments
Labels
feature request New feature or request

Comments

@rajhans

rajhans commented Apr 20, 2023

Hi folks,
Would you have any recommendations for how we can fine-tune gpt-neox on 8 A100 (80 GB) GPUs? We have a DeepSpeed setup figured out, but it OOMs at ZeRO-2/ZeRO-3; it does not OOM when using CPU offloading, but that is incredibly slow.

I am happy to run your fork of DeepSpeed. However, your config here mentions 96 A100s, which is far more than what we have.

An alternative we are considering is just doing LoRA, which can work. I am just exploring whether it is possible to fine-tune the full model.

Any insights are greatly appreciated.

@rajhans rajhans added the feature request New feature or request label Apr 20, 2023
@StellaAthena
Member

StellaAthena commented Apr 20, 2023

8x 80 GB = 640 GB VRAM. A 20B model takes approximately 360 GB of VRAM to train, so it should fit. Note that the GPUs we used when training the model were 40 GB ones, not 80 GB ones.
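For a rough sense of where a number like that comes from, here is a back-of-the-envelope estimate (this assumes the usual mixed-precision Adam accounting of roughly 16 bytes per parameter for weights, gradients, and optimizer states; the exact breakdown may differ):

$$
20 \times 10^{9}\ \text{params} \times 16\ \text{bytes/param} \approx 320\ \text{GB},
$$

with activations and framework overhead accounting for the rest of the ~360 GB.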

Just to make sure nothing is going wrong in set-up, can you try launching pretraining instead? Does that also OOM? Or is it only when launching finetuning?

The other thing that stands out is that we generally don’t use ZeRO-3, instead opting for ZeRO-1 + 3D parallelism. Given the numbers I don’t think that’s what’s going on, but I just wanted to make you aware of that.

@Quentin-Anthony
Member

+1 to @StellaAthena's suggestion to first try pretraining a 20B config with 3D parallelism. That way you can debug your config in a vacuum. Start by increasing model parallelism to 8 and setting ZeRO stage 1:

  "pipe-parallel-size": 1,
  "model-parallel-size": 8,

and

"zero_optimization": {
    "stage": 1,
...

as a sanity check. Other things to make sure you're doing to use less memory (in roughly decreasing order of impact; a combined config sketch follows the list):

  • Turn on activation checkpointing ("checkpoint-activations": true,) and activation partitioning ("partition-activations": true,)
  • Use fp16 training ("fp16": true, within the "fp16" dict)
  • Reduce the per-GPU batch size to 1 ("train_micro_batch_size_per_gpu": 1,)
  • Use flash attention ("attention_config": [[["flash"], x]], where x is the num-layers config option, 44 if you're using our 20B config).
  • Turn on our fused kernels "scaled-upper-triang-masked-softmax-fusion": true, and "bias-gelu-fusion": true,
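Putting the settings above together, a minimal sketch of the relevant config entries might look like the following (values are illustrative, assume the 20B config with 44 layers, and would need to be merged into your existing pretraining/finetuning config rather than used standalone):

  # 3D parallelism: tensor-parallel across all 8 GPUs, no pipeline stages
  "pipe-parallel-size": 1,
  "model-parallel-size": 8,

  # ZeRO-1 rather than ZeRO-2/3
  "zero_optimization": {
    "stage": 1
  },

  # activation checkpointing and partitioning
  "checkpoint-activations": true,
  "partition-activations": true,

  # fp16 training (the other keys of your existing "fp16" dict are omitted here)
  "fp16": {
    "fp16": true
  },

  # smallest per-GPU micro batch
  "train_micro_batch_size_per_gpu": 1,

  # flash attention for all 44 layers of the 20B config
  "attention_config": [[["flash"], 44]],

  # fused kernels
  "scaled-upper-triang-masked-softmax-fusion": true,
  "bias-gelu-fusion": true,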

Once you get a pre-training config running with these options, apply them to your finetuning config. If nothing works, send your pre-training config and we'll take a look :)

@rajhans
Author

rajhans commented Apr 21, 2023

Thanks a lot, folks. We had not explored 3D parallelism or flash attention before, although we did use fp16, smaller microbatches, and activation checkpointing.

One high-level question (just for my own mental model of the LLM infra landscape): many of these constructs, like 3D parallelism and flash attention, are not available in DeepSpeed, and you added support for them in DeeperSpeed. Correct?

We will give this a spin. Thanks again -- you're truly doing great work.

@Quentin-Anthony
Member

Features like tensor parallelism and flash attention were added within gpt-neox itself. Our DeeperSpeed is now very similar to upstream DeepSpeed. The two are (for now) functionally equivalent, and the only differences are some bug fixes.

@cateto

cateto commented May 23, 2023

@StellaAthena May I ask why you generally don't use ZeRO-3, instead opting for ZeRO-1 + 3D parallelism?

@EasonLi24

@StellaAthena I'm also very interested in that. Can you explain it to us, please?
