Adding adapters to SpeechBrain (Code from Samsung AI Center Cambridge) #2534
base: develop
Conversation
I don't think this is quite the right approach, because I don't think it allows for stopping/restarting, which is part of the point of checkpointing. Instead, the checkpointer should store the LoRA'd model, not the pretrained model. Ideally it would store only the LoRA weights (and any updated weights) rather than the whole model, making for very small checkpoints and faster saving. Example:

```yaml
add_adapters: !name:speechbrain.lobes.models.Adapters.add_adapters_to_linear_in_model
    adapter_class: !name:speechbrain.lobes.models.Adapters.HoulsbyAdapterLinear
    projection_size: 32

pretrainer: !new:speechbrain....Pretrainer
    loadables:
        transformer: !ref <Transformer> # Load the pretrained model.
```

```python
run_on_main(hparams["pretrainer"].collect_files)
hparams["pretrainer"].load_collected()
hparams["add_adapters"](hparams["Transformer"])

asr_brain = ASR(
    modules=hparams["modules"],
    opt_class=hparams["Adam"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],  # Checkpointer loads LoRA weights only and applies them
)

asr_brain.fit(
    asr_brain.hparams.epoch_counter, train_data, valid_data,
)
```
It does allow for stop and restart, because you are altering the object, i.e. the checkpointer keeps track of it! The only problem is indeed that you store the whole model. However, I don't think that's an issue, because ultimately you may simply not know where to put the pre-trained adapters in the model if they are not applied to every linear layer, for instance. I'd be happy to see a functional example of something else though.
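The "altering the object" argument can be illustrated with a tiny stub (hypothetical stand-in classes, not the real SpeechBrain objects): because the checkpointer holds a reference to the model rather than a copy, in-place modifications such as inserting adapters remain visible to it.

```python
# Sketch: a checkpointer that keeps a reference to the model still "sees"
# adapters that are inserted in place afterwards. Stub classes only.
class Model:
    def __init__(self):
        self.layers = {"linear": "Linear(512, 512)"}

class Checkpointer:
    def __init__(self, model):
        self.model = model  # a reference, not a copy

model = Model()
ckpt = Checkpointer(model)
model.layers["linear"] = "Adapter(Linear(512, 512))"  # in-place edit
# ckpt.model.layers["linear"] now shows the adapted module
```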
Tbh I think PEFT handles this perfectly; perhaps we should lift their code wholesale.
You mean depend on another Huggingface library?
My opinion is we should just add it as a dependency, but I understand the objections to it. So instead we could just copy the parts of the code that make sense into speechbrain.
If you could give me a neat example of an integration of PEFT, I could be convinced.
The problem with peft is when we want to load the model from the speechbrain checkpoint. It is a mess to make it work, and it could also cause problems when using different versions of peft. But maybe we could find a way to do this in a cleaner way.
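For reference, the core trick that PEFT's LoRA implements is small: a frozen weight `W` is augmented with a low-rank product `B @ A`, and only `A` and `B` are trained and checkpointed. A dependency-free numeric sketch (toy sizes and made-up values, not PEFT's actual code):

```python
# Toy LoRA forward pass in pure Python (no torch/peft): y = (W + B @ A) x.
# Only A (r x d_in) and B (d_out x r) would be trained and saved; with
# r << d, this is why LoRA checkpoints stay tiny.
def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]

def matmul(B, A):
    return [[sum(b * a for b, a in zip(row, col)) for col in zip(*A)]
            for row in B]

W = [[1.0, 0.0], [0.0, 1.0]]   # frozen pretrained weight (d_out x d_in)
A = [[1.0, 1.0]]               # trainable, rank r = 1
B = [[0.5], [0.5]]             # trainable, rank r = 1
BA = matmul(B, A)              # low-rank update, same shape as W
W_eff = [[w + d for w, d in zip(rw, rd)] for rw, rd in zip(W, BA)]

x = [1.0, 2.0]
y = matvec(W_eff, x)           # -> [2.5, 3.5]
```

In PEFT itself this wrapping is done for you by `get_peft_model` with a `LoraConfig`; the question above is whether depending on that library is worth it.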
Shouldn't this go in `nnet` rather than `lobes`?
Good question. It's unclear, because Adapters can be considered "entire models" coming from the literature. But I think I agree that they can also be seen as small components. I'd be happy if you could help with the `get_model`, like for PEFT. From your previous PR, I liked the fact that we can rely on the larger Adapter base from PEFT -- I am wondering if there isn't a way to combine both...
I personally like the fact that, with my function, you can actually specify which part of `Brain.modules` (or whatever model) you want to put Adapters on. But I'd be happy to see something else.
What does this PR do?
Based on this #2526, this PR is a first attempt at adding Adapters to any SB model. This will only work if the Pretrainer is used, not the checkpointer: the checkpointer would try to reload weights after the state_dict has been modified. So you need to: 1. instantiate the Brain; 2. call the Pretrainer; 3. add the adapters; 4. call fit. An example is:
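Since a real training script depends on a full recipe, the required ordering can be shown with a stub sketch (fake objects, not the actual SpeechBrain API). The point is that pretrained weights must be loaded before adapters modify the state_dict, which is why the checkpointer cannot be used here:

```python
# Stub sketch of the four-step order described above; the `calls` list
# records the sequence: init Brain -> load pretrained -> add adapters -> fit.
calls = []

class Pretrainer:
    def load_collected(self):
        calls.append("load_pretrained")  # must run before the state_dict changes

def add_adapters(model):
    calls.append("add_adapters")  # modifies the state_dict in place

class Brain:
    def __init__(self):
        calls.append("init_brain")
    def fit(self):
        calls.append("fit")

brain = Brain()                 # 1. instantiate the Brain
Pretrainer().load_collected()   # 2. call the Pretrainer
add_adapters(brain)             # 3. add the adapters
brain.fit()                     # 4. call fit
```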