
20B pretrained model inference OOM on 8xA100 40GB #901

Open
Mutinifni opened this issue Apr 23, 2023 · 3 comments
Labels: bug (Something isn't working), good first issue (Good for newcomers)

Mutinifni commented Apr 23, 2023

Describe the bug

Inference using the 20B pretrained model from the README, with the slim weights and the 20B.yml config, runs out of memory on 8xA100 40GB GPUs. I tried setting pipe-parallel-size and model-parallel-size to values in [1, 2, 4], but none of the combinations worked.

torch.cuda.OutOfMemoryError: CUDA out of memory. 
Tried to allocate 10.13 GiB (GPU 5; 39.43 GiB total capacity; 35.45 GiB already allocated; 2.39 GiB free; 35.45 GiB reserved in total by PyTorch) 
If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  
See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

I also ran into an OOM with the huggingface 20B model, based on the code in this repo (see #782).

RuntimeError: CUDA out of memory. 
Tried to allocate 592.00 MiB (GPU 0; 39.43 GiB total capacity; 38.05 GiB already allocated; 474.25 MiB free; 38.14 GiB reserved in total by PyTorch) 
If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  
See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
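
For rough context on the second traceback: 20B parameters take about 40 GB in fp16 and about 80 GB in fp32 (the transformers default when no torch_dtype is passed), so the full model cannot fit on a single 40 GB A100. A minimal sketch of loading the Hugging Face checkpoint in half precision, sharded across the available GPUs with plain transformers + accelerate (illustrative only, not the deepy.py generate path), would look like:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

# fp16 halves the weight memory; device_map="auto" (via accelerate) spreads
# the layers over however many GPUs are visible.
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-neox-20b",
    torch_dtype=torch.float16,
    device_map="auto",
)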

To Reproduce

  • Download pretrained weights
  • Copy over the 20B.yml config file from the pretrained-weights download; optionally modify its parameters
  • Run:
    python ./deepy.py generate.py -d configs 20B.yml local_setup.yml text_generation.yml

Expected behavior
Inference should not OOM (training OOM is expected).

Proposed solution
Not sure.

Environment (please complete the following information):

  • GPUs: 8xA100 40GB SXM
  • Configs: 20B.yml.

Should I try any different configuration options to make this work? Thanks!

@Mutinifni Mutinifni added the bug Something isn't working label Apr 23, 2023
@Mutinifni Mutinifni changed the title 20B pretrained model inference OOM on 8xA100 40GB SXM 20B pretrained model inference OOM on 8xA100 40GB Apr 23, 2023
@StellaAthena StellaAthena added the good first issue Good for newcomers label Apr 30, 2023
satpalsr (Contributor) commented May 1, 2023

Hey @Mutinifni,
While this is being worked on, you can use DeepSpeed MII or Accelerate.

Here's a snippet for DeepSpeed MII:

import mii

# Two-way tensor parallelism, fp16 weights, and staging the load through system
# memory keep the per-GPU footprint of the 20B model manageable.
mii_configs = {"tensor_parallel": 2, "dtype": "fp16", "load_with_sys_mem": True}

mii.deploy(task="text-generation",
           model="EleutherAI/gpt-neox-20b",
           deployment_name="gpt-neox-20b-deploy",
           mii_config=mii_configs)

# Query the deployment.
generator = mii.mii_query_handle("gpt-neox-20b-deploy")
result = generator.query({"query": ["DeepSpeed is", "Seattle is"]}, do_sample=True, max_new_tokens=30)

# Terminate the deployment if you no longer want to run inference.
mii.terminate("gpt-neox-20b-deploy")
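
For what it's worth, with tensor_parallel=2 and fp16 weights the roughly 40 GB of 20B parameters splits to about 20 GB per GPU, which leaves headroom for activations on a 40 GB A100. To see the completions, the response object returned by query() can simply be printed (before the terminate call above), e.g.:

# The response object holds the generated continuations for both prompts.
print(result)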

And for Accelerate, something like this:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
model = AutoModelForCausalLM.from_pretrained('EleutherAI/gpt-neox-20b')

# Compute a per-GPU memory budget that balances the model across all visible
# GPUs, without splitting individual transformer blocks.
max_memory = get_balanced_memory(
    model,
    max_memory=None,
    no_split_module_classes=["GPTNeoXLayer"],
    dtype='float16',
    low_zero=False,
)

# Map each module to a device under that budget.
device_map = infer_auto_device_map(
    model,
    max_memory=max_memory,
    no_split_module_classes=["GPTNeoXLayer"],
    dtype='float16'
)

# Place the model's modules on their assigned devices.
model = dispatch_model(model, device_map=device_map)
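
Something like the following should then work for generation with the dispatched model (a sketch, not tested here); note that from_pretrained above loads fp32 weights unless torch_dtype=torch.float16 is passed, while the device-map sizes were computed assuming fp16, so loading in half precision is probably what you want:

# Generate with the dispatched model; put the inputs on the device that holds
# the first layers (the embedding), which accelerate assigned during dispatch.
prompt = "GPT-NeoX-20B is"
inputs = tokenizer(prompt, return_tensors="pt").to(next(model.parameters()).device)
with torch.no_grad():
    output_ids = model.generate(**inputs, do_sample=True, max_new_tokens=30)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))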

Mutinifni (Author) commented
Thanks @satpalsr! DeepSpeed MII worked for me (with just 2 GPUs).

I would like to ask a follow-up question to understand this a little bit better. Based on the DeepSpeed MII latency/cost analysis, it looks like DeepSpeed MII performs much better than the baseline (presumably huggingface transformers), so is there any reason to prefer the huggingface model for deployment? Do DeepSpeed MII or accelerate underperform with larger GPU deployments, or are they strictly better?

StellaAthena (Member) commented
> Thanks @satpalsr! DeepSpeed MII worked for me (with just 2 GPUs).
>
> I would like to ask a follow-up question to understand this a little bit better. Based on the DeepSpeed MII latency/cost analysis, it looks like DeepSpeed MII performs much better than the baseline (presumably huggingface transformers), so is there any reason to prefer the huggingface model for deployment? Do DeepSpeed MII or accelerate underperform with larger GPU deployments, or are they strictly better?

Our understanding is that it's strictly better. We're currently working on replacing our current inference backend with it.
