Fix unnecessary VRAM usage while injecting fused attn #453
Conversation
In the current implementation, the old q/k/v_proj modules are not released until the end of the entire loop, which roughly doubles the peak VRAM usage of the fused qkv_layer. This commit fixes the issue. So far, only the fused GPTJ injection has been modified.
Fix the same issue in fused_llama_attn.
Oh wow, that sounds interesting. Do you have any numbers on how much that might reduce VRAM usage during a quant? I will test it myself when I have a chance.
I think it's not related to the quantization stage; it only reduces the peak VRAM cost during model loading with fused attention injection enabled. Here is my peak memory usage on an RTX 4060 Ti 16G:

[peak memory numbers not preserved]
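Such a peak measurement can be captured around the load with torch.cuda's memory statistics; a minimal sketch (the model id is illustrative, and `inject_fused_attention` is assumed to be the AutoGPTQ flag for this code path):

```python
import torch
from auto_gptq import AutoGPTQForCausalLM

torch.cuda.reset_peak_memory_stats()
model = AutoGPTQForCausalLM.from_quantized(
    "some-user/some-GPTQ-model",   # illustrative model id, not the one measured
    device="cuda:0",
    inject_fused_attention=True,   # the injection code path this PR optimizes
)
print(f"peak VRAM: {torch.cuda.max_memory_allocated() / 2**20:.0f} MiB")
```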
Oh, sorry, misunderstood. That's an excellent inference improvement though! Very nice.
@lszxb Thank you for this PR! Before merging, I'd like to make sure I understand what is going on here. I can't reproduce the issue with the script below:

```python
from transformers import AutoModel
import torch
import torch.nn as nn
from transformers.models.bert.modeling_bert import BertSelfAttention
import time

with torch.device("cuda"):
    model = AutoModel.from_pretrained("bert-base-uncased")

def recurse_setattr(module, name, value):
    """A function to recursively set attributes to a module."""
    if "." not in name:
        setattr(module, name, value)
    else:
        name, rest = name.split(".", 1)
        recurse_setattr(getattr(module, name), rest, value)

class MyBertAttention(nn.Module):
    def __init__(self, qkv_proj):
        super().__init__()
        self.qkv_proj = qkv_proj

    def forward(self, x):
        return self.qkv_proj(x)

for name, m in model.named_modules():
    if not isinstance(m, BertSelfAttention):
        continue
    print("override")

    # Concatenate the separate q/k/v projections into one fused layer.
    qkv_layer_weight = torch.cat(
        [
            m.query.weight,
            m.key.weight,
            m.value.weight,
        ]
    )
    bias = m.query.bias is not None
    if bias:
        qkv_layer_bias = torch.cat(
            [
                m.query.bias,
                m.key.bias,
                m.value.bias,
            ]
        )

    with torch.device("meta"):
        qkv_proj = nn.Linear(m.query.weight.shape[1], m.query.weight.shape[0] * 3, bias=bias)
    qkv_proj.weight = torch.nn.Parameter(qkv_layer_weight)
    if bias:
        qkv_proj.bias = torch.nn.Parameter(qkv_layer_bias)

    recurse_setattr(model, name, MyBertAttention(qkv_proj))
    time.sleep(0.2)
```

Here the overriding does not seem to duplicate memory.
@fxmarty It's about the deduplication memo in `named_modules`:

```python
def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True):
    if memo is None:
        memo = set()
    if self not in memo:
        if remove_duplicate:
            memo.add(self)
        yield prefix, self
        for name, module in self._modules.items():
            if module is None:
                continue
            submodule_prefix = prefix + ('.' if prefix else '') + name
            yield from module.named_modules(memo, submodule_prefix, remove_duplicate)
```
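The pinning can be seen in isolation with a minimal sketch (stock PyTorch only, unrelated to AutoGPTQ):

```python
import gc
import weakref

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4))
gen = model.named_modules()   # the dedup memo is built lazily inside the generator
next(gen)                     # yields ('', model); memo == {model}
next(gen)                     # yields ('0', linear); memo == {model, linear}

old = model[0]
ref = weakref.ref(old)
model[0] = nn.Identity()      # the Linear is no longer part of the model
del old
gc.collect()
print(ref() is None)          # False: the generator's memo still holds the Linear

del gen                       # dropping the generator releases the memo
gc.collect()
print(ref() is None)          # True: now the old module can be collected
```

This is exactly what happens to each old q/k/v_proj while the injection loop is still iterating the generator.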
It seems that there is a problem with your testing script in the model loading part. On my machine, I modified the script, and there is a difference in the peak memory usage:

```python
from transformers import AutoModel
import torch
import torch.nn as nn
from transformers.models.bert.modeling_bert import BertSelfAttention
import time

with torch.device("cuda"):
    model = AutoModel.from_pretrained("bert-base-uncased")

print(torch.cuda.max_memory_allocated())
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

def recurse_setattr(module, name, value):
    """A function to recursively set attributes to a module."""
    if "." not in name:
        setattr(module, name, value)
    else:
        name, rest = name.split(".", 1)
        recurse_setattr(getattr(module, name), rest, value)

class MyBertAttention(nn.Module):
    def __init__(self, qkv_proj):
        super().__init__()
        self.qkv_proj = qkv_proj

    def forward(self, x):
        return self.qkv_proj(x)

for name, m in model.named_modules(remove_duplicate=False):  # change remove_duplicate to see a difference
    if not isinstance(m, BertSelfAttention):
        continue
    print("override")

    qkv_layer_weight = torch.cat(
        [
            m.query.weight,
            m.key.weight,
            m.value.weight,
        ]
    )
    bias = m.query.bias is not None
    if bias:
        qkv_layer_bias = torch.cat(
            [
                m.query.bias,
                m.key.bias,
                m.value.bias,
            ]
        )

    with torch.device("meta"):
        qkv_proj = nn.Linear(m.query.weight.shape[1], m.query.weight.shape[0] * 3, bias=bias)
    qkv_proj.weight = torch.nn.Parameter(qkv_layer_weight)
    if bias:
        qkv_proj.bias = torch.nn.Parameter(qkv_layer_bias)

    recurse_setattr(model, name, MyBertAttention(qkv_proj))
    time.sleep(0.2)

print(torch.cuda.max_memory_allocated())  # here is 524114944 vs 453234688 on my machine
```
In PyTorch, `Module.named_modules()` returns a generator. During iteration, the deduplication `memo` set keeps a reference to every submodule that has been yielded, which prevents the old `self_attn` and `q/k/v_proj` modules from being released until the end of the entire loop, and roughly doubles the peak VRAM usage of `qkv_layer`. This PR fixes the issue by first recording all submodules in a list, then removing each submodule's reference from the list once it has been accessed.

BTW, using `Module.named_modules(remove_duplicate=False)` would also fix this issue, but I think it's much more dependent on the internal implementation of PyTorch.
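A minimal sketch of that pattern, reusing `recurse_setattr` and `MyBertAttention` from the scripts above (`build_fused_qkv` is a hypothetical helper that concatenates the q/k/v weights exactly as those scripts do; this is not the exact PR diff):

```python
# Materialize the generator first: once list() returns, the generator is
# exhausted and its internal dedup memo no longer pins any submodule.
modules = list(model.named_modules())

for i in range(len(modules)):
    name, m = modules[i]
    modules[i] = None  # drop the list's reference as soon as the entry is taken
    if not isinstance(m, BertSelfAttention):
        continue
    recurse_setattr(model, name, MyBertAttention(build_fused_qkv(m)))
    del m  # last remaining reference; the old q/k/v projections can be freed now
```

Because each list entry is cleared as soon as it is visited, a replaced module becomes unreachable immediately instead of surviving in the memo until the loop ends.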