[FEATURE: ADD LISA ALGORITHM] #3103
Conversation
Fixes: #3087
Adapted from OptimalScale/LMFlow#726.
When combining LISA with multiple GPUs, ZeRO-3, and gradient checkpointing, the following error occurs:

```
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:7 and cpu! (when checking argument for argument tensors in method wrapper_CUDA_cat)
  return torch._C._nn.flatten_dense_tensors(tensors)
  single_grad_partition = self.flatten(self.averaged_gradients[sub_group_id]).to...
```
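The traceback shows gradient tensors from two devices being flattened together. As a rough illustration only (this is not the fix adopted in the PR), one way to guarantee a single device is to walk the optimizer state after switching layers; `move_optimizer_state` below is a hypothetical helper:

```python
import torch

def move_optimizer_state(optimizer, device):
    """Hypothetical workaround sketch (not this PR's actual fix): relocate
    every optimizer state tensor onto a single device so that later
    flatten/cat calls do not see a mix of cuda and cpu tensors."""
    for state in optimizer.state.values():
        for key, value in state.items():
            if torch.is_tensor(value):
                state[key] = value.to(device)
```

Whether this helps under ZeRO-3 depends on where DeepSpeed keeps its own gradient partitions, so treat it as a debugging aid rather than a solution.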
I came up with the code below:

```python
def on_train_epoch_start(self, trainer: "L.Trainer", pl_module: "pl.LightningModule"):
    if trainer.current_epoch % self.epoch_interval == 0:
        self.switch_active_layers()
        pl_module.optimizer_fn = torch.optim.Adam
        trainer.strategy.setup_optimizers(trainer)
```
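For readers unfamiliar with what `switch_active_layers` does, the core LISA idea is to periodically freeze all transformer layers except a small random subset. A minimal sketch follows; the class and parameter names (`LisaLayerSwitcher`, `n_active`) are illustrative, not the PR's actual API:

```python
import torch
import torch.nn as nn

class LisaLayerSwitcher:
    """Sketch of LISA-style layer sampling: each call freezes every layer,
    then unfreezes a fresh random subset of `n_active` layers."""

    def __init__(self, layers, n_active=2):
        self.layers = list(layers)
        self.n_active = n_active

    def switch_active_layers(self):
        # Sample n_active distinct layer indices to train this interval.
        active = set(torch.randperm(len(self.layers))[: self.n_active].tolist())
        for i, layer in enumerate(self.layers):
            for p in layer.parameters():
                p.requires_grad_(i in active)
        return sorted(active)
```

In the callback above this would run every `epoch_interval` epochs, after which the optimizer is rebuilt via `trainer.strategy.setup_optimizers(trainer)` so it tracks the newly active parameters.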
I have conducted experiments on Llama-2-7B using the full, lisa_2, and lisa_32 methods. From the image above, you can see that the training loss curve decreases, and full matches lisa_32. The latest code borrows some of the implementation from LMFlow and Axolotl; some implementation details were cleaned up, and a debug option was added. I hope this will be merged.
When I used LISA to fine-tune Llama-2-7B on alpaca-gpt4-en with one A100 80G, the memory used increased sharply and exceeded 80G. I want to know how to solve this problem. Config:
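One speculative contributor to such a memory spike (not confirmed for this report): if the optimizer is built over all parameters rather than only the currently active layers, Adam allocates `exp_avg`/`exp_avg_sq` buffers for the whole 7B model, roughly doubling per-parameter memory. A minimal sketch of restricting optimizer state to trainable parameters, with an assumed helper name:

```python
import torch

def build_optimizer_for_active_layers(model, lr=1e-4):
    # Sketch only: pass just the requires_grad parameters to AdamW, so
    # moment buffers are allocated for the active layers alone instead
    # of for every frozen layer as well.
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(trainable, lr=lr)
```

This would need to be re-run whenever the active layer subset changes, and old optimizer state should be released so it can be garbage-collected.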
@neteroster Hello. I have the same problem with you, have you solved it? |
@lovekdl Not yet. |
> I have conducted experiments on llama2-7b using full, lisa_2, lisa_32 methods. From the image above, you can see that the train loss curve decreases and full is the same as lisa_32.

When will this PR be merged?
What does this PR do?
New feature: adds the LISA algorithm. See: https://arxiv.org/abs/2403.17919
Before submitting