Fix unnecessary VRAM usage while injecting fused attn #453

Open
wants to merge 2 commits into base: main

Conversation

@lszxb lszxb commented Nov 28, 2023

In PyTorch, Module.named_modules() returns a generator. During iteration, a deduplication memo is created that keeps a reference to every visited submodule, which prevents releasing the old self_attn and q/k/v_proj modules until the end of the entire loop and roughly doubles the peak VRAM usage of qkv_layer.

This PR fixes the issue by first recording all submodules in a list and then removing each entry from the list once it has been accessed.

By the way, using Module.named_modules(remove_duplicate=False) would also fix this issue, but I think it depends much more on the internal implementation of PyTorch.
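
For illustration, here is a minimal sketch of the idea (the helper name is mine; the actual PR code may be organized differently): collect the (name, module) pairs up front, then drop each entry from the list as soon as it is handed out, so nothing keeps every old submodule alive for the whole loop.

import torch.nn as nn

def named_modules_releasing(model: nn.Module):
    # Record all submodules up front instead of relying on named_modules()'s
    # internal memo, which would hold every visited module until the generator
    # is exhausted.
    modules = list(model.named_modules())
    while modules:
        # Pop the entry out of the list before yielding it: once the caller
        # replaces this module, only the loop variables here still reference
        # it, so it can be freed before the rest of the model is processed.
        name, module = modules.pop(0)
        yield name, module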

lszxb added 2 commits November 28, 2023 16:41
In the current implementation, the old q/k/v_proj modules are not
released until the end of the entire loop, which roughly doubles the
peak VRAM usage of qkv_layer. This commit fixes the issue.
Only the fused GPTJ injection has been modified so far.
Fix the same issue in fused_llama_attn.
@TheBloke
Contributor

Oh wow, that sounds interesting. Do you have any numbers on how much that might reduce VRAM usage by during a quant?

I will test it myself when I have a chance

@lszxb
Author

lszxb commented Nov 28, 2023

> Oh wow, that sounds interesting. Do you have any numbers on how much that might reduce VRAM usage by during a quant?
>
> I will test it myself when I have a chance

I think it's not related to the quantization stage. It only reduces the peak VRAM cost during model loading with inject_fused_attention=True in from_quantized (only LLaMA and GPTJ models are currently supported) for inference.

Here is my peak memory usage (in bytes) on an RTX 4060 Ti 16G:

| Model | Peak VRAM w/ this PR | Peak VRAM w/o this PR |
| --- | --- | --- |
| TheBloke_Orca-2-13B-GPTQ_gptq-4bit-32g-actorder_True | 8,240,750,592 | 10,014,685,184 |
| TheBloke_Orca-2-13B-GPTQ_gptq-8bit-128g-actorder_True | 14,052,081,664 | OOM |
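
For context, the load being measured looks roughly like this (a sketch for my setup; the local model path and the extra arguments are placeholders, not part of the PR):

import torch
from auto_gptq import AutoGPTQForCausalLM

# Rough sketch of the measured loading path. The model directory is a local
# placeholder; inject_fused_attention=True is the code path this PR changes.
model = AutoGPTQForCausalLM.from_quantized(
    "TheBloke_Orca-2-13B-GPTQ_gptq-4bit-32g-actorder_True",
    device="cuda:0",
    use_safetensors=True,
    inject_fused_attention=True,
)
print(torch.cuda.max_memory_allocated())  # one way to read peak usage in bytes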

@TheBloke
Contributor

Oh, sorry, misunderstood

That's an excellent inference improvement though! Very nice.

@fxmarty
Collaborator

fxmarty commented Dec 7, 2023

@lszxb Thank you for this PR! Before merging, I'd like to make sure I understand what is going on here. I can't reproduce the issue with named_modules:

from transformers import AutoModel
import torch
import torch.nn as nn
from transformers.models.bert.modeling_bert import BertSelfAttention
import time

with torch.device("cuda"):
    model = AutoModel.from_pretrained("bert-base-uncased")

def recurse_setattr(module, name, value):
    """A function to recursively set attributes to a module."""
    if "." not in name:
        setattr(module, name, value)
    else:
        name, rest = name.split(".", 1)
        recurse_setattr(getattr(module, name), rest, value)

class MyBertAttention(nn.Module):
    def __init__(self, qkv_proj):
        super().__init__()
        self.qkv_proj = qkv_proj
    
    def forward(self, x):
        return self.qkv_proj(x)

for name, m in model.named_modules():
    if not isinstance(m, BertSelfAttention):
        continue
    
    print("override")
    qkv_layer_weight = torch.cat(
        [
            m.query.weight,
            m.key.weight,
            m.value.weight,
        ]
    )
    
    bias = m.query.bias is not None
    if bias:
        qkv_layer_bias = torch.cat(
            [
                m.query.bias,
                m.key.bias,
                m.value.bias,
            ]
        )

    with torch.device("meta"):
        qkv_proj = nn.Linear(m.query.weight.shape[1], m.query.weight.shape[0] * 3, bias=bias)

    qkv_proj.weight = torch.nn.Parameter(qkv_layer_weight)
    if bias:
        qkv_proj.bias = torch.nn.Parameter(qkv_layer_bias)
        
    
    recurse_setattr(model, name, MyBertAttention(qkv_proj))
    
    time.sleep(0.2)

Here the overriding does not seem to duplicate memory.

@lszxb
Author

lszxb commented Dec 8, 2023

@fxmarty It's about the remove_duplicate option of named_modules, which was added in PyTorch 1.9 and defaults to True. In the current implementation, when remove_duplicate is True, a memo set keeps a reference to every previously visited module until the loop ends.

def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True):
    if memo is None:
        memo = set()
    if self not in memo:
        if remove_duplicate:
            memo.add(self)
        yield prefix, self
        for name, module in self._modules.items():
            if module is None:
                continue
            submodule_prefix = prefix + ('.' if prefix else '') + name
            yield from module.named_modules(memo, submodule_prefix, remove_duplicate)
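
Here is a small CPU-only toy of my own (not part of the PR) that shows the effect: a submodule that has already been replaced in the model stays alive while the named_modules() generator is still running, because the memo set still references it.

import gc
import weakref
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
old_first = weakref.ref(model[0])  # watch the first Linear without keeping it alive

for name, m in model.named_modules():
    if name == "0":
        model[0] = nn.Identity()  # replace it, as the fused injection does
    elif name == "1":
        # The loop variable and the model no longer reference the old Linear,
        # yet it is still alive: the generator's memo set holds it.
        gc.collect()
        print("old module alive mid-loop:", old_first() is not None)   # True
        # With model.named_modules(remove_duplicate=False), this prints False.

gc.collect()
print("old module alive after loop:", old_first() is not None)         # False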

It seems there is a problem with the model-loading part of your testing script. On my machine, AutoModel.from_pretrained uses too much memory (for some unknown reason; this is not the case when loading an LLM), and the duplicated memory usage cannot be observed because of caching.

I modified the script, and there is a difference in peak memory usage:

from transformers import AutoModel
import torch
import torch.nn as nn
from transformers.models.bert.modeling_bert import BertSelfAttention
import time

with torch.device("cuda"):
    model = AutoModel.from_pretrained("bert-base-uncased")

print(torch.cuda.max_memory_allocated())
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

def recurse_setattr(module, name, value):
    """A function to recursively set attributes to a module."""
    if "." not in name:
        setattr(module, name, value)
    else:
        name, rest = name.split(".", 1)
        recurse_setattr(getattr(module, name), rest, value)

class MyBertAttention(nn.Module):
    def __init__(self, qkv_proj):
        super().__init__()
        self.qkv_proj = qkv_proj
    
    def forward(self, x):
        return self.qkv_proj(x)

for name, m in model.named_modules(remove_duplicate=False): # change remove_duplicate to see a difference
    if not isinstance(m, BertSelfAttention):
        continue
    
    print("override")
    qkv_layer_weight = torch.cat(
        [
            m.query.weight,
            m.key.weight,
            m.value.weight,
        ]
    )
    
    bias = m.query.bias is not None
    if bias:
        qkv_layer_bias = torch.cat(
            [
                m.query.bias,
                m.key.bias,
                m.value.bias,
            ]
        )

    with torch.device("meta"):
        qkv_proj = nn.Linear(m.query.weight.shape[1], m.query.weight.shape[0] * 3, bias=bias)

    qkv_proj.weight = torch.nn.Parameter(qkv_layer_weight)
    if bias:
        qkv_proj.bias = torch.nn.Parameter(qkv_layer_bias)
        
    
    recurse_setattr(model, name, MyBertAttention(qkv_proj))
    
    time.sleep(0.2)

print(torch.cuda.max_memory_allocated()) # here is 524114944 vs 453234688 on my machine
