FEAT: Add Badam optimizer #30692
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Nice job :) Very straightforward.
Thanks for the work adding this!
At the moment, I think there's some work needed to move the handling logic into different methods to prevent leakiness, i.e. not all of these methods need to know about badam.
p.s. I can't read badam and not think about this Kylie banger
if use_badam:
    self.optim = "badam_" + self.optim
Having to remove and then add this back seems both convoluted and an indication that there's something funny about the implementation. Specifically, it seems like something which should be handled on the `OptimizerNames` side.
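To make this concrete, here is a minimal sketch, assuming a plain `Enum` stand-in for the library's `OptimizerNames`; the `BADAM` member and its value are hypothetical, not the merged implementation:

```python
# Sketch only: register BAdam as its own OptimizerNames entry instead of
# prefixing the existing optim string. The BADAM member is hypothetical.
from enum import Enum

class OptimizerNames(str, Enum):
    ADAMW_TORCH = "adamw_torch"
    SGD = "sgd"
    BADAM = "badam"

# TrainingArguments could then validate args.optim against the enum directly,
# so no other Trainer method needs to strip or re-add a "badam_" prefix.
optim = OptimizerNames("badam")
assert optim is OptimizerNames.BADAM
```

With a dedicated enum value, the rest of the Trainer only sees a normal optimizer name and the special-casing stays in one place.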
if badam_kwargs is not None:
    from badam import BlockOptimizer

    self.optimizer = BlockOptimizer(
        base_optimizer=self.optimizer,
        named_parameters_list=list(opt_model.named_parameters()),
        block_prefix_list=None,
        **badam_kwargs,
    )
This is another indication of a peculiar pattern - why set `self.optimizer` and then overwrite here? Instead, the correct optimizer class and kwargs should be handled in `Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model)`.
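To illustrate (this is a sketch under assumed names, not the transformers API), resolving BAdam in a single "class + kwargs" step could look roughly like this, with the `BlockOptimizer` arguments taken from the diff above:

```python
# Sketch: pick the optimizer class and kwargs in one place, in the spirit of
# Trainer.get_optimizer_cls_and_kwargs. The helper name and defaults here are
# illustrative assumptions.
import torch

def resolve_optimizer_cls_and_kwargs(optim_name, model, lr=1e-5, badam_kwargs=None):
    if optim_name == "badam":
        from badam import BlockOptimizer  # third-party dependency used by this PR

        # BlockOptimizer wraps a base optimizer, so the base is built here and
        # the wrapper is the single optimizer the Trainer instantiates.
        base_optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        kwargs = {
            "base_optimizer": base_optimizer,
            "named_parameters_list": list(model.named_parameters()),
            "block_prefix_list": None,
            **(badam_kwargs or {}),
        }
        return BlockOptimizer, kwargs

    return torch.optim.AdamW, {"lr": lr}
```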
if args.optim_args:
    for mapping in args.optim_args.replace(" ", "").split(","):
        key, value = mapping.split("=")
        if "badam_" in key and use_badam:
Is there ever a case when `use_badam` is `False` and `badam_` is in the key?
            badam_optim_args[key.replace("badam_", "")] = value
        else:
            optim_args[key] = value
We shouldn't need to do this string manipulation and have these parallel `optim_args`. This doesn't scale well if we want to add other optimizers which accept some of the previous optim args.
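As a hedged alternative (helper names and keys below are placeholders, not part of this PR), the raw `optim_args` string could be parsed once and each optimizer could select only the keys it declares, with no prefix stripping or parallel dicts:

```python
# Hypothetical helpers: parse optim_args once, then hand each optimizer only
# the keys it accepts, instead of keeping a second badam_optim_args dict.
def parse_optim_args(optim_args_str):
    """Turn 'a=1, b=2' into {'a': '1', 'b': '2'}."""
    if not optim_args_str:
        return {}
    return dict(item.split("=", 1) for item in optim_args_str.replace(" ", "").split(","))

def select_kwargs(all_args, accepted_keys):
    """Keep only the keys a given optimizer declares that it accepts."""
    return {k: v for k, v in all_args.items() if k in accepted_keys}

all_args = parse_optim_args("lr_scale=0.5, switch_mode=random")  # placeholder keys
badam_kwargs = select_kwargs(all_args, {"switch_mode"})          # {'switch_mode': 'random'}
```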
if args.optim_args:
    for mapping in args.optim_args.replace(" ", "").split(","):
        key, value = mapping.split("=")
        if "badam_" in key and use_badam:
nit - check the bool flag first in the `and` check, it'll be faster:

Suggested change:
-        if "badam_" in key and use_badam:
+        if use_badam and "badam_" in key:
What does this PR do?
Fixes: #30308
This PR adds the Badam optimizer to the transformers Trainer API.
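For reference, usage on this PR's branch might look like the sketch below; the `use_badam` flag and the `badam_`-prefixed `optim_args` keys mirror the current diff and could change during review, and the key shown is a placeholder:

```python
from transformers import TrainingArguments

# Only valid on this PR's branch; flag and key names follow the diff above.
training_args = TrainingArguments(
    output_dir="out",
    optim="adamw_torch",                   # base optimizer that BAdam wraps
    optim_args="badam_example_key=value",  # placeholder badam_-prefixed kwarg
    use_badam=True,
)
```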
TODOs:
cc @amyeroberts @muellerzr