
LongLoRA implementation #8341

Open
wants to merge 3 commits into base: develop
Conversation

bruicecode

PR types

Others

PR changes

Others

Description

  1. Modified how llm/data.py handles RedPajama-style examples that contain only a text field.
  2. In LlamaAttention's forward, shift q/k/v before calling scaled_dot_product_attention and shift attn_output back after the computation (a minimal sketch is given below).
  3. Adjusted the RoPE scaling in llm/finetune_generation.py so the model can handle sequences longer than 2048 tokens.
  4. Made the norm and embedding layers trainable in llm/finetune_generation.py.
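
For context, the sketch below outlines the shifted sparse attention (S²-Attn) pattern described in item 2. It is a minimal, hedged illustration following the LongLoRA paper's reference implementation, not the exact code in this PR; the function names and the [bsz, q_len, num_heads, head_dim] layout are assumptions.

```python
import paddle

def shift_qkv(x, group_size, num_heads):
    # x: [bsz, q_len, num_heads, head_dim]. Roll the second half of the heads
    # forward by half a group so information can flow across group boundaries.
    x[:, :, num_heads // 2:] = paddle.roll(x[:, :, num_heads // 2:], shifts=-(group_size // 2), axis=1)
    return x

def to_groups(x, bsz, q_len, group_size, num_heads, head_dim):
    # Fold groups into the batch dimension so each block of group_size tokens
    # attends only within itself during scaled_dot_product_attention.
    num_group = q_len // group_size
    return paddle.reshape(x, (bsz * num_group, group_size, num_heads, head_dim))

def shift_back(attn_output, bsz, q_len, group_size, num_heads, head_dim):
    # Undo the grouping and roll the shifted heads back to their original positions.
    attn_output = paddle.reshape(attn_output, (bsz, q_len, num_heads, head_dim))
    attn_output[:, :, num_heads // 2:] = paddle.roll(
        attn_output[:, :, num_heads // 2:], shifts=group_size // 2, axis=1
    )
    return attn_output
```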


paddle-bot bot commented Apr 28, 2024

Thanks for your contribution!


attn_output = paddle.reshape(attn_output, (bsz, q_len, self.num_heads, self.head_dim))
attn_output[:, :, self.num_heads // 2:] = paddle.roll(
    attn_output[:, :, self.num_heads // 2:], shifts=group_size // 2, axis=1
)
attn_output = paddle.reshape(attn_output, (bsz, q_len, self.hidden_size))
Contributor

The shift-back logic here is wrong:

# attn_output: [bsz * num_group, group_size, num_heads * head_dim] -> [bsz, q_len, num_heads, head_dim]
attn_output = paddle.reshape(attn_output, (bsz, num_group * group_size, self.num_heads, self.head_dim))
attn_output[:, :, self.num_heads // 2:] = paddle.roll(
    attn_output[:, :, self.num_heads // 2:], shifts=group_size // 2, axis=1
)
attn_output = paddle.reshape(attn_output, (bsz, num_group * group_size, self.hidden_size))

num_group = q_len // group_size
attention_mask = create_attention_mask((bsz, q_len), dtype="float16")
attention_mask = attention_mask[:, :, :group_size, :group_size]
Contributor

The attention_mask logic needs extra handling:

if self.config.shift:
    assert len(attention_mask.shape) == 2, "attention_mask should be 2-dim for shift"
    bs = input_ids.shape[0]
    attention_mask = attention_mask.reshape([bs * self.config.group_num, -1])
    expanded_attn_mask = _expand_2d_mask(attention_mask, self.config.dtype, tgt_length=attention_mask.shape[-1])
    # For the decoding phase in generation, seq_length == 1, so no causal mask is needed.
    if attention_mask.shape[-1] > 1:
        past_key_values_length = paddle.shape(past_key_values[0][0])[1] if past_key_values is not None else 0
        combined_attention_mask = _make_causal_mask(
            attention_mask.shape, past_key_values_length=past_key_values_length
        )
        if get_env_device() == "npu":
            expanded_attn_mask = expanded_attn_mask.astype("bool")
            combined_attention_mask = combined_attention_mask.astype("bool")
        attention_mask = expanded_attn_mask & combined_attention_mask
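
For reference, here is a hedged, self-contained sketch of the per-group causal mask that shift attention needs once the sequence has been folded into [bsz * num_group, group_size, ...]. It uses plain paddle ops instead of the _make_causal_mask / _expand_2d_mask helpers above, and the [batch, 1, tgt_len, src_len] shape convention is an assumption.

```python
import paddle

def group_causal_mask(bsz, num_group, group_size, dtype="float16"):
    # Lower-triangular pattern: position i may attend to positions <= i within its group.
    causal = paddle.tril(paddle.ones([group_size, group_size], dtype="bool"))
    # 0 where attention is allowed, a large negative value where it is masked.
    mask = paddle.where(
        causal,
        paddle.zeros([group_size, group_size], dtype=dtype),
        paddle.full([group_size, group_size], -1e4, dtype=dtype),
    )
    # The same [group_size, group_size] mask applies to every group in the folded batch.
    return mask.unsqueeze([0, 1]).expand([bsz * num_group, 1, group_size, group_size])
```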

@@ -45,16 +45,53 @@ def get_convert_example(model):
if base_model_prefix == "chatglm":
return convert_example_chatglm
elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen", "mixtral", "gemma"]:
return convert_example_common
return convert_example_common_meta_text
Contributor

Please don't change the dataset-reading logic. Preprocess the data into the llm model format first, then load it through the existing pipeline. (A sketch of that approach follows.)
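
A hedged sketch of that suggestion: convert the text-only RedPajama records into the src/tgt JSON-lines format the llm pipeline already reads, as an offline step before training. The file names and the empty-source convention here are assumptions for illustration, not code from this PR.

```python
import json

with open("red_pajama_raw.jsonl") as fin, open("red_pajama_sft.jsonl", "w") as fout:
    for line in fin:
        example = json.loads(line)
        # Plain language-modeling data has no instruction/response split, so one
        # simple choice is an empty source and the raw text as the target.
        fout.write(json.dumps({"src": "", "tgt": example["text"]}, ensure_ascii=False) + "\n")
```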

# Attention! In some cases (e.g. ChatGLMv2), the tokenized eos_token is not equal to eos_token_id.
if len(tokenized_target_input_ids) < tgt_max_length:
    tokenized_target_input_ids += [tokenizer.eos_token_id]
return tokenized_source, tokenized_target_input_ids
Contributor

The reason for handling the data this way is not clear to me.

model_config.sep_parallel_degree = training_args.sep_parallel_degree
model_config.tensor_parallel_output = True
model_config.seq_length = data_args.max_length

# Set RoPE scaling factor
orig_rope_scaling_factor = model_config.rope_scaling_factor
Contributor

This will be deprecated. Going forward, please use LongSequenceStrategies instead: https://github.com/PaddlePaddle/PaddleNLP/pull/8076/files

model.config.long_sequence_strategy_type = "attention_strategies"
model.config.long_sequence_strategy_name = "LinearScalingRotaryEmbedding"
model.config.long_sequence_init_args = {
    "head_dim": head_dim,
    "max_position_embeddings": max_position_embeddings,
    "rope_scaling_type": rope_scaling_type,
    "rope_scaling_factor": rope_scaling_factor,
}

@@ -458,7 +473,7 @@ def neft_post_hook(module, input, output):

if model_args.lora:
if model_args.lora_path is None:
target_modules = get_lora_target_modules(model)
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
lora_config = LoRAConfig(
target_modules=target_modules,
Contributor

lora_config already has a parameter for this, trainable_modules, e.g. [".*embed.*", ".*norm.*"]:
https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/peft/lora/lora_config.py#L47
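
A hedged sketch of that suggestion; the r, lora_alpha, and dtype values are illustrative, not taken from this PR.

```python
from paddlenlp.peft import LoRAConfig

lora_config = LoRAConfig(
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    # Keep the embedding and norm weights trainable via the existing config knob
    # instead of flipping stop_gradient manually in finetune_generation.py.
    trainable_modules=[".*embed.*", ".*norm.*"],
    r=8,
    lora_alpha=16,
    dtype="bfloat16",
)
```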

model.recompute_enable()
for param in model.parameters():
    if not param.stop_gradient and param.grad is None:
        param.clear_gradient()
Contributor

What is this for?

return_tensors="np",
pad_to_multiple_of=data_args.pad_to_multiple_of,
max_length=max_length,
Contributor

What is this change for?

@@ -553,7 +591,8 @@ def compute_metrics_do_generation(eval_preds):
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
# train_result = trainer.train(resume_from_checkpoint=checkpoint)
train_result = trainer.train()
Contributor

? This line should not be changed.

return True
return False


Contributor

Don't just pull the latest code from the main branch and paste the old code back on top of it.

@@ -838,20 +846,34 @@ def forward(
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism)

def create_attention_mask(input_shape, dtype):
Contributor

Remove this.
