
FEAT: Add Badam optimizer #30692

Draft · younesbelkada wants to merge 2 commits into main

Conversation

@younesbelkada (Contributor) commented May 7, 2024

What does this PR do?

Fixes: #30308

This PR adds the Badam optimizer to the transformers Trainer API.
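A rough usage sketch, based on the current diff rather than a finalized API (use_badam and the "badam_"-prefixed optim_args keys are what this PR wires up; switch_block_every is only an illustrative BAdam option, not a settled name):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    optim="adamw_torch",                       # base optimizer; internally prefixed to "badam_adamw_torch"
    use_badam=True,                            # flag added by this PR to wrap the base optimizer with BAdam
    optim_args="badam_switch_block_every=50",  # "badam_"-prefixed keys are forwarded to BlockOptimizer
)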

TODOs:

  • Add ratio optimizer
  • Add layer optimizer
  • Add docs

cc @amyeroberts @muellerzr

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@muellerzr (Contributor) left a comment

Nice job :) Very straightforward.

@amyeroberts (Collaborator) left a comment

Thanks for the work adding this!

At the moment, I think there's some work needed to move the handling logic into different methods to prevent leakiness, i.e. not all of these methods need to know about badam.

p.s. I can't read badam and not think about this Kylie banger

Comment on lines +1628 to +1630

if use_badam:
    self.optim = "badam_" + self.optim
Collaborator

Having to remove and then add this back seems both convoluted and an indication that there's something funny about the implementation. Specifically, it seems like something which should be handled on the OptimizerNames side.
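For illustration, a sketch of that direction (the real OptimizerNames in training_args.py is an ExplicitEnum with many more members; the BADAM_* names here are hypothetical, not this PR's code):

from enum import Enum

class OptimizerNames(str, Enum):
    # declare the BAdam-wrapped variants up front instead of building the
    # name by string concatenation at runtime
    ADAMW_TORCH = "adamw_torch"
    BADAM_ADAMW_TORCH = "badam_adamw_torch"

# selection then stays a plain enum lookup, with no remove/re-add step
optim = OptimizerNames("badam_adamw_torch")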

Comment on lines +1083 to +1091
if badam_kwargs is not None:
    from badam import BlockOptimizer

    self.optimizer = BlockOptimizer(
        base_optimizer=self.optimizer,
        named_parameters_list=list(opt_model.named_parameters()),
        block_prefix_list=None,
        **badam_kwargs,
    )
Collaborator

This is another indication of a peculiar pattern: why set self.optimizer and then overwrite it here? Instead, the correct optimizer class and kwargs should be handled in Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model).
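A sketch of that direction, assuming the get_optimizer_cls_and_kwargs signature used above (args.use_badam and args.badam_kwargs are hypothetical fields, not this PR's final API):

from typing import Any, Dict, Tuple

import torch

def get_optimizer_cls_and_kwargs(args, opt_model) -> Tuple[Any, Dict[str, Any]]:
    # resolve the optimizer class and kwargs in one place instead of building
    # self.optimizer and then overwriting it afterwards
    optimizer_cls: Any = torch.optim.AdamW
    optimizer_kwargs: Dict[str, Any] = {"lr": args.learning_rate}

    if getattr(args, "use_badam", False):
        from badam import BlockOptimizer  # third-party BAdam package

        # BlockOptimizer wraps an already-built base optimizer, so build it here
        # and pass everything it needs as kwargs (mirroring the diff above)
        base_optimizer = optimizer_cls(opt_model.parameters(), **optimizer_kwargs)
        optimizer_cls = BlockOptimizer
        optimizer_kwargs = {
            "base_optimizer": base_optimizer,
            "named_parameters_list": list(opt_model.named_parameters()),
            "block_prefix_list": None,
            **(getattr(args, "badam_kwargs", None) or {}),
        }

    return optimizer_cls, optimizer_kwargs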

if args.optim_args:
    for mapping in args.optim_args.replace(" ", "").split(","):
        key, value = mapping.split("=")
        optim_args[key] = value
        if "badam_" in key and use_badam:
Collaborator

Is there ever a case when use_badam is False and badam_ is in the key?

Comment on lines +1151 to +1153
    badam_optim_args[key.replace("badam_", "")] = value
else:
    optim_args[key] = value
Collaborator

We shouldn't need to do this string manipulation and have these parallel optim_args. This doesn't scale well if we want to add other optimizers which accept some of the previous optim args.
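One possible shape for that, sketched with a stand-in dataclass (the badam_kwargs field name and the example option are illustrative, not part of this PR):

from dataclasses import dataclass, field
from typing import Any, Dict, Optional

@dataclass
class TrainingArgumentsSketch:
    # keep BAdam settings in their own structured field so nothing has to be
    # split back out of the shared optim_args string by prefix matching
    optim_args: Optional[str] = None
    badam_kwargs: Dict[str, Any] = field(default_factory=dict)

args = TrainingArgumentsSketch(badam_kwargs={"switch_block_every": 50})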

if args.optim_args:
    for mapping in args.optim_args.replace(" ", "").split(","):
        key, value = mapping.split("=")
        optim_args[key] = value
        if "badam_" in key and use_badam:
Collaborator

nit - check the bool flag first in the and check, it'll be faster: short-circuiting then skips the substring check whenever use_badam is False.

Suggested change
-        if "badam_" in key and use_badam:
+        if use_badam and "badam_" in key:

Labels: None yet
Projects: None yet
Development: Successfully merging this pull request may close these issues: Badam support
4 participants