
About model training #24

Open
chencn2020 opened this issue Mar 18, 2024 · 2 comments

Comments

@chencn2020

chencn2020 commented Mar 18, 2024

Hello,

During training, have you ever run into the training hanging at the first epoch while GPU utilization stays at 100%?

At first I thought it was a server problem, but as soon as I delete the MASK-token part of the code below, training runs normally:

```python
if cur_input_ids.numel() > 0:
    if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
        # Adapter-tuning branch: everything is detached, so no gradient
        # flows back through these embeddings.
        mask_idx = torch.nonzero(cur_input_ids == self.tokenizer.convert_tokens_to_ids(['<mask>'])[0])
        _l = 0
        for i, idx in enumerate(mask_idx):
            # text tokens before this <mask>
            cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[_l:idx[0]]).detach())
            ## mask feature
            cur_new_input_embeds.append(mask_feats[batch_idx][i:i+1].detach())
            ## pos feature
            cur_new_input_embeds.append(pos_feats[batch_idx][i:i+1].detach())
            if labels is not None:
                # the two spliced positions are excluded from the loss
                cur_labels[idx[0]:idx[0]+2] = torch.full((2,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)
            _l = idx[0] + 2
        if _l < len(cur_input_ids):
            cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[_l:]).detach())

    else:
        mask_idx = torch.nonzero(cur_input_ids == self.tokenizer.convert_tokens_to_ids(['<mask>'])[0])
        assert len(mask_idx) == len(mask_feats[batch_idx]), "mask num not equal to mask feats"

        _l = 0
        for i, idx in enumerate(mask_idx):
            # text tokens before this <mask>, with gradients enabled
            cur_raw_new_input_embeds = self.get_model().embed_tokens(cur_input_ids[_l:idx[0]])
            cur_new_input_embeds.append(cur_raw_new_input_embeds)
            ## mask feature
            cur_new_input_embeds.append(mask_feats[batch_idx][i:i+1].to(cur_raw_new_input_embeds.dtype))
            ## pos feature
            cur_new_input_embeds.append(pos_feats[batch_idx][i:i+1].to(cur_raw_new_input_embeds.dtype))

            if labels is not None:
                cur_labels[idx[0]:idx[0]+2] = torch.full((2,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)

            _l = idx[0] + 2
        if _l < len(cur_input_ids):
            cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[_l:]))

    if labels is not None:
        cur_new_labels.append(cur_labels)
```
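For reference, the splicing pattern in that loop can be illustrated with a minimal, self-contained sketch, independent of the model (plain Python with dummy feature placeholders instead of `embed_tokens`, `mask_feats`, and `pos_feats`; the `<mask>` token id and all names here are hypothetical, not from the repository):

```python
IGNORE_INDEX = -100
MASK_ID = 999  # hypothetical id standing in for tokenizer.convert_tokens_to_ids(['<mask>'])[0]

def splice_mask_feats(input_ids, labels, mask_feats, pos_feats):
    """Mirror the loop above: each <mask> position is replaced by a
    (mask_feat, pos_feat) pair, and the cursor then advances two
    positions, so the token right after <mask> is also consumed.
    The two spliced label positions are set to IGNORE_INDEX."""
    new_embeds, new_labels = [], list(labels)
    mask_idx = [i for i, t in enumerate(input_ids) if t == MASK_ID]
    assert len(mask_idx) == len(mask_feats), "mask num not equal to mask feats"
    _l = 0
    for i, idx in enumerate(mask_idx):
        # text tokens before this <mask> (stand-in for embed_tokens)
        new_embeds.extend(("tok", t) for t in input_ids[_l:idx])
        new_embeds.append(("mask", mask_feats[i]))  # mask feature
        new_embeds.append(("pos", pos_feats[i]))    # pos feature
        new_labels[idx:idx + 2] = [IGNORE_INDEX, IGNORE_INDEX]
        _l = idx + 2
    if _l < len(input_ids):
        new_embeds.extend(("tok", t) for t in input_ids[_l:])
    return new_embeds, new_labels

# one <mask> at position 2, followed by a placeholder token that gets consumed
emb, lab = splice_mask_feats([1, 2, MASK_ID, 0, 3], [1, 2, MASK_ID, 0, 3], ["m0"], ["p0"])
```

Under these toy inputs, `emb` interleaves the text tokens with the mask/pos pair and `lab` has positions 2 and 3 masked out, which matches what the real loop does to `cur_new_input_embeds` and `cur_labels`.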
@LiWentomng
Collaborator

Hello @chencn2020,
We have also run into a similar situation; it usually happens when the amount of training data is very large. We suspect it may be related to performance limits of the server.
If you have any good ideas, you are welcome to discuss further or open a PR.

@xuzf-git

[attached screenshot]
Hello, I am running stage-3 training with LoRA fine-tuning on my own dataset, using a single A100 with batch size 16. Is this speed normal?
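As a rough sanity check on speed, throughput and ETA can be derived from the progress bar; a minimal sketch (the step time and step count below are hypothetical placeholders, not taken from the screenshot):

```python
batch_size = 16        # as stated above
sec_per_step = 12.0    # hypothetical: read this off the trainer's progress bar
total_steps = 5000     # hypothetical: total optimizer steps for the run

samples_per_sec = batch_size / sec_per_step          # examples processed per second
eta_hours = total_steps * sec_per_step / 3600        # wall-clock estimate for the run

print(f"{samples_per_sec:.2f} samples/s, ~{eta_hours:.1f} h to finish")
```

Comparing this number against the same calculation on the authors' reported hardware is a quick way to tell whether a single-A100 run is unusually slow.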
