Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][wenet/LLM] support LLMs #2460

Open
wants to merge 43 commits into
base: main
Choose a base branch
from
Open

[WIP][wenet/LLM] support LLMs #2460

wants to merge 43 commits into from

Conversation

Mddct
Copy link
Collaborator

@Mddct Mddct commented Apr 7, 2024

为下一步的SpeechLLM 打基础

TODO

  • make it works

    • fintune
    • pretrain
    • dataset
      • sft
      • pretrain
    • special tokens and stop tokens redesign
    • generate
  • convert some model

    • qwen
    • Llama 3
      • 8b
      • 8b-it
      • 70b
      • 70b-it
    • gemma
      • 2b
      • 7b
      • 2b-it
      • 7b-it
      • Code 2B && 7B
    • internlm
  • Llama3 70b 模型并行度为8, 分别在attention的q,k,v 和feed fowrad weight上进行了col row等的切分,需要引入fairscale, 做模型并行。 并且官方给了8个pt, 每个16G左右。

TODO

  • 其他pr中引入 model_parallel, fairscale,

@Mddct Mddct mentioned this pull request Apr 7, 2024
24 tasks
@Mddct
Copy link
Collaborator Author

Mddct commented Apr 12, 2024

为什么不把embeding和out 放到decoderonly里边?

其他模态的注入是从embeding开始的,保持decoder only 有embeding的入参。

如果embeing和out share weight,fsdp 需要embeding 和out 在同一个level上,

我们经常会扩充词表,resize embed 和resize out,放最外层不影响decoderonly

@Mddct Mddct changed the title [WIP text/LLM] support LLMs [WIP][text/LLM] support LLMs Apr 12, 2024
@Mddct
Copy link
Collaborator Author

Mddct commented Apr 14, 2024

gemma 精度测试

# configs = {"decoder": "decoder_only", "output_dim": 256000, "model_conf": {}}

import torch
from wenet.text.LLM.script.convert_gemma_to_wenet_config_and_ckpt import (
    get_config_for_2b, get_config_for_7b)
from wenet.utils.init_model import init_model

from gemma.model import GemmaForCausalLM
from gemma.config import (get_config_for_2b as google_2b_config_fn,
                          get_config_for_7b as google_7b_config_fn)

import argparse


def get_args():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument(
        '--gemma_ckpt',
        required=True,
        help='https://www.kaggle.com/models/google/gemma/frameworks/pyTorch')
    parser.add_argument(
        '--gemma_tokenizer',
        required=True,
        help='https://www.kaggle.com/models/google/gemma/frameworks/pyTorch')

    parser.add_argument(
        '--wenet_gemma_ckpt',
        required=True,
        help='https://www.kaggle.com/models/google/gemma/frameworks/pyTorch')
    parser.add_argument('--model_size', type=str, required=True)
    args = parser.parse_args()
    return args


args = get_args()
args.jit = False

layers = 18 if args.model_size == '2b' else 28
if args.model_size == '2b':
    config = get_config_for_2b()
else:
    config = get_config_for_7b()
model_conf = {
    'model': 'causal_lm',
    'output_dim': config.vocab_size,
    'decoder': 'decoder_only',
    'tokenizer_conf': {
        "special_tokens": {
            'sos': 0,
            'eos': 1
        }
    }
}
decoder_conf = {}
decoder_conf['n_kv_head'] = config.num_key_value_heads
decoder_conf['head_dim'] = config.head_dim
decoder_conf['hidden_size'] = config.hidden_size
decoder_conf['attention_heads'] = config.num_attention_heads
decoder_conf['linear_units'] = config.intermediate_size
decoder_conf['num_blocks'] = layers

decoder_conf['max_position_embeding'] = 8192
decoder_conf['activation_type'] = 'gelu'
decoder_conf['gelu_approximate'] = 'tanh'
decoder_conf['norm_eps'] = config.rms_norm_eps
decoder_conf['use_sdpa'] = True
model_conf['decoder_conf'] = decoder_conf
model_conf['model_conf'] = {}

args.checkpoint = args.wenet_gemma_ckpt
model, _ = init_model(args, model_conf)
model.eval()

# get google gemma model
if args.model_size == '2b':
    google_config = google_2b_config_fn()
else:
    google_config = google_7b_config_fn()
google_config.tokenizer = args.gemma_tokenizer

google_gemma = GemmaForCausalLM(google_config)
google_gemma.load_weights(
    args.gemma_ckpt)
google_gemma.eval()
scale = google_config.hidden_size

batch_size = torch.randint(2, 10, ())
seq_len = torch.randint(3, 20, ())
text = torch.randint(0, config.vocab_size, (batch_size, seq_len))


def google_forward(google_gemma,
                   batch_size,
                   token_ids,
                   seq_len,
                   scale,
                   layers=18):

    google_freqs_cis = google_gemma.freqs_cis
    google_emb = google_gemma.embedder
    google_gemma = google_gemma.model

    input_positions_tensor = torch.arange(0, seq_len)
    google_freqs_cis = google_freqs_cis.index_select(0, input_positions_tensor)
    google_hidden_states = google_emb(token_ids)
    google_hidden_states = google_hidden_states * (scale**0.5)
    # mask_tensor = torch.full((2, 1, 10, 10), -2.3819763e38).to(torch.float)
    mask_tensor = torch.full((batch_size, 1, seq_len, seq_len),
                             0).to(torch.float)
    kv_caches = []
    for _ in range(layers):
        size = (batch_size, seq_len, google_config.num_key_value_heads,
                google_config.head_dim)
        k_cache = torch.zeros(size=size)
        v_cache = torch.zeros(size=size)
        kv_caches.append((k_cache, v_cache))
    google_output = google_gemma(
        google_hidden_states,
        google_freqs_cis,
        input_positions_tensor,
        kv_caches,
        mask_tensor,
    )
    google_output = torch.matmul(google_output, google_emb.weight.T)
    return google_output


def wenet_forward(wenet_model, batch_size, token_ids, seq_len, layers=18):
    hidden_states = wenet_model.embed(token_ids)
    wenet_kv_caches = []
    for _ in range(layers):
        size = (0, 0, 0, 0)
        k_cache = torch.zeros(size=size)
        v_cache = torch.zeros(size=size)
        wenet_kv_caches.append((k_cache, v_cache))

    att_mask_tensor = torch.ones(batch_size,
                                 seq_len,
                                 seq_len,
                                 dtype=torch.bool)
    wenet_output, _ = model.decoder(hidden_states,
                                    att_mask_tensor.squeeze(1),
                                    kv_caches=wenet_kv_caches)

    wenet_output = model.out(wenet_output)
    return wenet_output


wenet_output = wenet_forward(model, batch_size, text, seq_len, layers)
google_output = google_forward(google_gemma, batch_size, text, seq_len, scale,
                               layers)

print(wenet_output)
print(google_output)
assert torch.allclose(wenet_output, google_output)

@Mddct Mddct force-pushed the Mddct-llm branch 4 times, most recently from 1ea0839 to 64ff835 Compare April 25, 2024 12:16
@Mddct
Copy link
Collaborator Author

Mddct commented Apr 26, 2024

sft:
2b gemma fsdp zero3
截屏2024-04-26 15 54 46

@Mddct Mddct force-pushed the Mddct-llm branch 2 times, most recently from 31a8dcd to 782998d Compare April 27, 2024 02:53
@Mddct
Copy link
Collaborator Author

Mddct commented Apr 28, 2024

generate in batch way:

gemma:
截屏2024-04-28 19 15 46

llama:

截屏2024-04-30 11 41 56

@Mddct
Copy link
Collaborator Author

Mddct commented Apr 30, 2024

解释下这里为什么要把shape 变成[bs, seq_len,head, head_dim]

https://github.com/wenet-e2e/wenet/blob/9805ed68638f711b6fda17627efb7aa918ce6870/wenet/transformer/attention.py#L637-#L651

来自gpt4的解释:
110971714449428_ pic

实测[bs, seq_len,head, head_dim], 对head_dim 上apply pos等操作要慢于[bs,head,seq_len, head_dim]

6s vs 2s (长度为300)

ref: https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L256

所以其他xxx attention 是否也需要有对应修改?

@fclearner
Copy link
Contributor

fclearner commented Apr 30, 2024

周神,torch官方也有个llama微调的代码:https://github.com/pytorch/torchtune

@Mddct
Copy link
Collaborator Author

Mddct commented Apr 30, 2024

周神,torch官方也有个llama微调的代码:https://github.com/pytorch/torchtune

嗯 这个有看过。 不过我们最终目的不是llm 而是为了语音理解大模型和语音合成

而且大模型训练 有自己的设计原则和技巧 我们需要把优秀的组件 继承过来

@Mddct
Copy link
Collaborator Author

Mddct commented May 31, 2024

该pr会拆分成以下加个pr

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants