
[BUG]: HybridParallelOptimizer holds unsharded model parameters after sharding #5539

Open
insujang opened this issue Mar 31, 2024 · 9 comments · Fixed by #5545
Labels
bug Something isn't working

Comments

@insujang (Contributor)

🐛 Describe the bug

When using tensor parallelism, model parameters are sharded across GPUs to reduce memory consumption and enable parallel execution.
However, the optimizer still holds references to the unsharded model parameters, preventing them from being released and increasing memory usage.

Example code (adapted from examples/language/gpt2/hybridparallelism/finetune.py):

colossalai.launch_from_torch(config={})
plugin = HybridParallelPlugin(tp_size=4, pp_size=1)
booster = Booster(plugin=plugin)
optimizer = Adam(model.parameters())
# initialize dataloader

model, optimizer, *_ = booster.boost(model, optimizer, ...)
> model.module.transformer.wte.weight
Parameter containing:
tensor([[-0.1101, -0.0393, ...]], device='cuda:0', dtype=torch.float16, requires_grad=True)

> model.module.transformer.wte.weight.shape
torch.Size([12565, 768])

> optimizer.param_groups[0]["params"][0]
Parameter containing:
tensor([[-0.1101, -0.0393, ...]], device='cuda:0', requires_grad=True)

> optimizer.param_groups[0]["params"[0].shape
torch.Size([50257, 768])
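
For reference, the sharded shape reported above is consistent with padding the vocabulary to a multiple of the TP degree and splitting it row-wise across the 4 ranks; a small sketch with the numbers from this snippet (plain Python, not ColossalAI code):

import math

vocab_size, tp_size, hidden = 50257, 4, 768
padded_vocab = math.ceil(vocab_size / tp_size) * tp_size  # 50260
print(padded_vocab // tp_size, hidden)  # 12565 768, matching model.module.transformer.wte.weight.shape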

This also affects MixedPrecisionOptimizer.master_to_working_map and MixedPrecisionOptimizer.working_to_master_map:

# model.module.transformer.wte.weight is supposed to be a working parameter
> model.module.transformer.wte.weight.shape
torch.Size([12565, 768])
> id(model.module.transformer.wte.weight)
139684649437120

# First working parameter in map does not refer to this
> list(iter(optimizer.master_to_working_map))[0].shape
torch.Size([50257, 768])
> id(list(iter(optimizer.master_to_working_map))[0])
139693862695728

Because of this, it seems only a portion of the parameters (i.e. the ones that did not need resizing) are actually trained: MixedPrecisionOptimizer.step() skips the sharded parameters, because their gradients are not stored in the mismatched unsharded parameters the optimizer holds:

if working_param.grad is not None:
    p.grad = working_param.grad.data.float()
    working_param.grad = None
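
A hedged way to detect the mismatch programmatically (assuming master_to_working_map behaves like a dict from master to working parameters, as the iteration above suggests): count how many working parameters the optimizer tracks that are no longer attached to the sharded model.

model_param_ids = {id(p) for p in model.module.parameters()}
stale = [w for w in optimizer.master_to_working_map.values() if id(w) not in model_param_ids]
print(f"stale working params: {len(stale)} / {len(optimizer.master_to_working_map)}")  # non-zero on the buggy path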

Environment

PyTorch 2.2.1 / CUDA 12.1

insujang added the bug label on Mar 31, 2024
@Edenzzzz (Contributor) commented Apr 1, 2024

Hi, thanks for the issue.
I reproduced the bug using this script: finetune.zip
This might be due to some unexpected model movement without ZeRO. In most cases ZeRO is used and the params are sharded in place. I'm looking into this.

@Edenzzzz (Contributor) commented Apr 3, 2024

This happens only when sequence parallelism is on and ZeRO is off. We are rebuilding the sequence parallel API with ring attention etc., so I've set it to False in enable_all_optimization as a quick fix.

@insujang (Contributor, Author) commented Apr 3, 2024

@Edenzzzz, thank you for your time looking into this issue. I am not sure this fix works: I tested with enable_all_optimization=False, enable_sequence_parallelism=False, and enable_sequence_overlap=False, and the same problem still happens on my side. Could you check again?

Edit: this is the plugin configuration I used:

plugin = HybridParallelPlugin(
    tp_size=4,
    pp_size=1,
    num_microbatches=None,
    microbatch_size=1,
    enable_all_optimization=False,
    enable_sequence_parallelism=False,
    enable_sequence_overlap=False,
    zero_stage=0,
    precision="fp16",
    initial_scale=1,
)

@Edenzzzz (Contributor) commented Apr 6, 2024

This bug seems specific to a minority of TP plans. Will take another look

@insujang (Contributor, Author) commented Apr 8, 2024

Looks like preprocess in each policy might be the reason:

def preprocess(self):
    # reshape the embedding layer
    r"""
    Reshape the Embedding layer to make the embedding dimension divisible by world_size
    """
    if self.shard_config.enable_tensor_parallelism:
        vocab_size = self.model.config.vocab_size
        world_size = self.shard_config.tensor_parallel_size
        if vocab_size % world_size != 0:
            new_vocab_size = vocab_size + world_size - vocab_size % world_size
            self.model.resize_token_embeddings(new_vocab_size)
    return self.model
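
With the numbers from the failing gpt2 case, this rounding gives:

vocab_size, world_size = 50257, 4
new_vocab_size = vocab_size + world_size - vocab_size % world_size  # 50257 + 4 - 1 = 50260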

Although all policies have the same resize logic, each model has a different default vocabulary size, so only bert and gpt2 in your tests need their embeddings resized; the resize creates a new embedding and triggers the failure:

from transformers import AutoConfig

def test_vocab_size_divisible_to_tp_size(model_name: str, tp_size: int):
    config = AutoConfig.from_pretrained(model_name)
    vocab_size = config.vocab_size

    print(f"model {model_name} vocab_size: {vocab_size}. Need to resize embeddings for tp degree {tp_size}? {vocab_size % tp_size != 0}")

test_vocab_size_divisible_to_tp_size("gpt2", 8)
test_vocab_size_divisible_to_tp_size("bert-base-uncased", 8)
test_vocab_size_divisible_to_tp_size("facebook/opt-125m", 8)
test_vocab_size_divisible_to_tp_size("tiiuae/falcon-rw-1b", 8)
Output:

model gpt2 vocab_size: 50257. Need to resize embeddings for tp degree 8? True
model bert-base-uncased vocab_size: 30522. Need to resize embeddings for tp degree 8? True
model facebook/opt-125m vocab_size: 50272. Need to resize embeddings for tp degree 8? False
model tiiuae/falcon-rw-1b vocab_size: 50304. Need to resize embeddings for tp degree 8? False

resize_token_embeddings creates a completely new nn.Embedding, so the weight's ID changes:
https://github.com/huggingface/transformers/blob/76fa17c1663a0efeca7208c20579833365584889/src/transformers/modeling_utils.py#L2017-L2028

# Before calling `preprocess()` on gpt2:
id(model.transformer.wte.weight)
140670116084640
model.transformer.wte
Embedding(50257, 768)

# After calling `preprocess()` on gpt2:
id(model.transformer.wte.weight)
140670118343072
model.transformer.wte
Embedding(50260, 768)
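
A standalone PyTorch toy example (hypothetical, not ColossalAI code) showing why the stale reference survives: the optimizer captures Parameter objects at construction time, so replacing the module's weight afterwards leaves the optimizer pointing at the old tensor.

import torch
import torch.nn as nn

emb = nn.Embedding(50257, 768)
opt = torch.optim.Adam(emb.parameters())

# Roughly what resize_token_embeddings does: allocate a new, larger Parameter.
emb.weight = nn.Parameter(torch.zeros(50260, 768))

# The optimizer still references the original 50257 x 768 tensor.
print(opt.param_groups[0]["params"][0].shape)  # torch.Size([50257, 768])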

@insujang (Contributor, Author) commented Apr 8, 2024

A quick potential patch is to avoid HF's resize_token_embeddings and instead use nn.functional.pad to resize the tensor in place, avoiding recreation of the nn.Embedding (not sure whether other attributes also need to be modified):

from torch import nn

def resize_token_embedding_inplace(new_num_tokens: int, embedding: nn.Embedding):
    # In-place resize of the token embeddings: pad extra rows with zeros
    # instead of allocating a new nn.Embedding.
    embedding.num_embeddings = new_num_tokens
    embedding.weight.data = nn.functional.pad(
        embedding.weight.data,
        (0, 0, 0, new_num_tokens - embedding.weight.size(0)),
        "constant",
        0,
    )
 
# In policy
def preprocess(self):
    # reshape the embedding layer
    r"""
    Reshape the Embedding layer to make the embedding dimension divisible by world_size
    """
    if self.shard_config.enable_tensor_parallelism:
        vocab_size = self.model.config.vocab_size
        world_size = self.shard_config.tensor_parallel_size
        if vocab_size % world_size != 0:
            new_vocab_size = vocab_size + world_size - vocab_size % world_size

            resize_token_embedding_inplace(new_vocab_size, self.model.get_input_embeddings())
            # self.model.resize_token_embeddings(new_vocab_size)

    return self.model
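
A quick hedged check of this idea (toy numbers matching the gpt2 case above): the in-place resize should keep the same Parameter object, so references held elsewhere, e.g. by an already-constructed optimizer, stay valid.

emb = nn.Embedding(50257, 768)
weight_id = id(emb.weight)
resize_token_embedding_inplace(50260, emb)
assert id(emb.weight) == weight_id  # same Parameter object as before
assert emb.weight.shape == (50260, 768)  # three zero rows appended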

@Edenzzzz Could you please check if it works? Thanks

@insujang (Contributor, Author) commented Apr 8, 2024

Maybe it is related to #5489?

@Edenzzzz (Contributor) commented Apr 8, 2024

[Quoted from insujang's comment above: the proposed resize_token_embedding_inplace patch.]

Thanks for the nice catch! This worked for both gpt2 and bert. Yes, a fix appears to be in progress; will touch base with them tomorrow.

@insujang (Contributor, Author) commented Apr 8, 2024

Nice to hear that the fix will be merged very soon. Thank you!
