
Is model.generate supported during the training process? #30713

Open
sunxiaojie99 opened this issue May 8, 2024 · 3 comments
Comments

@sunxiaojie99

sunxiaojie99 commented May 8, 2024

System Info

torch 2.0.0
peft 0.4.0
transformers 4.38.0

Who can help?

@ArthurZucker @younesbelkada @SunMar

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

build code

from transformers import MistralForCausalLM
from peft import LoraConfig, get_peft_model

base_model = MistralForCausalLM.from_pretrained(model_name_or_path)
lora_config = LoraConfig(
    base_model_name_or_path=model_name_or_path,
    task_type="FEATURE_EXTRACTION",
    r=model_args.lora_r,
    lora_alpha=model_args.lora_alpha,
    lora_dropout=model_args.lora_dropout,
    target_modules=model_args.lora_target_modules.split(','),
    inference_mode=False,
)
lora_model = get_peft_model(base_model, lora_config)
lora_model.print_trainable_parameters()  # trainable params: 20,971,520 || all params: 7,262,703,616 || trainable%: 0.2887563792882719

train code (success)

self.encoder = lora_model
def encode_general(self, qry):
        inputs = {
            'input_ids': qry['input_ids'],
            'attention_mask': qry['attention_mask'],
        }
        psg_out = self.encoder.model(**inputs, output_hidden_states=True)

        reps = psg_out.hidden_states[-1][:, -1, :]
        return reps

train code (fail)

self.encoder = lora_model
def encode_generate(self, qry, max_gen_len):
        inputs = {
            'input_ids': qry['input_ids'],
            'attention_mask': qry['attention_mask'],
        }
        max_new_tokens = max_gen_len
        generation_output = self.encoder.generate(**inputs, return_dict_in_generate=True, max_new_tokens=max_new_tokens, output_hidden_states=True)
        return generation_output.hidden_states[0][-1][:, -1, :]

The error message is:

Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/root/paddlejob/workspace/env_run/llm-index/src/tevatron/retriever/driver/train_cot.py", line 126, in <module>
    main()
  File "/root/paddlejob/workspace/env_run/llm-index/src/tevatron/retriever/driver/train_cot.py", line 119, in main
    trainer.train()  # TODO: resume training
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 1624, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 1961, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/root/paddlejob/workspace/env_run/llm-index/src/tevatron/retriever/trainer.py", line 55, in training_step
    loss = super(TevatronTrainer, self).training_step(*args) / self._dist_loss_scale_factor
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 2911, in training_step
    self.accelerator.backward(loss)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/accelerator.py", line 2116, in backward
    self.deepspeed_engine_wrapped.backward(loss, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/utils/deepspeed.py", line 166, in backward
    self.engine.backward(loss, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/engine.py", line 1976, in backward
    self.optimizer.backward(loss, retain_graph=retain_graph)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/zero/stage3.py", line 2213, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Expected behavior

Successful training process

@younesbelkada
Contributor

Hi @sunxiaojie99
Unfortunately you can't backpropagate through generate, because that method is wrapped in a torch.no_grad() context manager (GenerationMixin.generate is decorated with @torch.no_grad() in src/transformers/generation/utils.py).
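The same RuntimeError can be reproduced with a bare tensor: anything computed inside torch.no_grad() carries no grad_fn, so backward() has nothing to differentiate. A minimal sketch:

import torch

x = torch.randn(3, requires_grad=True)
with torch.no_grad():
    y = (x * 2).sum()  # recorded outside the autograd graph: y.grad_fn is None
y.backward()  # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn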
If you really want to compute gradients while generating text, you will need to write a custom text generation loop, for example this one from @ArthurZucker: https://gist.github.com/ArthurZucker/5dc54a3fb443e979fac437e5df7c800b, making sure to remove the torch.no_grad() statement.
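For reference, here is a minimal sketch of such a gradient-preserving loop (greedy decoding only, no KV cache, no stopping criteria); it assumes model is the PEFT-wrapped MistralForCausalLM from the reproduction above and that input_ids / attention_mask come from the same batch:

import torch

def generate_with_grad(model, input_ids, attention_mask, max_new_tokens):
    # Greedy decoding without torch.no_grad(), so the autograd graph is kept.
    step_hidden_states = []
    for _ in range(max_new_tokens):
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # last layer, last position of this step's forward pass
        step_hidden_states.append(outputs.hidden_states[-1][:, -1, :])
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.size(0), 1))],
            dim=-1,
        )
    return input_ids, step_hidden_states

Here step_hidden_states[0] plays the role of generation_output.hidden_states[0][-1][:, -1, :] from the failing snippet, but keeps its grad_fn. Note that argmax is not differentiable, so gradients flow through the hidden states of each forward pass rather than the discrete token choices, and the loop recomputes the full prefix at every step because it does not use past_key_values.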

@sunxiaojie99
Author

https://gist.github.com/ArthurZucker/5dc54a3fb443e979fac437e5df7c800b

Thank you very much!

@younesbelkada
Contributor

You are welcome @sunxiaojie99 !
