
[WIP] Testing the lion optimizer #432

Draft
wants to merge 1 commit into base: main

Conversation

mitchellnw
Contributor

@lucidrains
Contributor

@mitchellnw oh nice! yea let us try it out and add your anecdata here

in my mind, a learning rate cooldown is still needed with this technique

@lucidrains
Contributor

lucidrains commented Feb 19, 2023

@rwightman have you tried Lion yet for any of the vision transformers training? your voice is one of the more definitive ones out there

@mitchellnw
Contributor Author

So far at small scale (a short B/32 run, batch size 16k), well-tuned Lion slightly outperforms AdamW (still tuning AdamW).

AdamW (LR 2e-3, WD 0.2, betas=0.9, 0.95) = 42.1
Lion (LR 2e-4, WD 2, betas=0.9,0.99) = 42.7

Trying to run analogous experiments for H but resources have been busy.
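(For reference, a minimal sketch of how these two settings might be instantiated in PyTorch; the Lion implementation is assumed to be lucidrains' lion-pytorch package, which is not necessarily what this PR uses.)

```python
import torch
import torch.nn as nn
from lion_pytorch import Lion  # assumed Lion implementation (lucidrains/lion-pytorch)

model = nn.Linear(512, 512)  # stand-in for the actual CLIP model

# AdamW baseline from the numbers above: LR 2e-3, WD 0.2, betas (0.9, 0.95)
adamw = torch.optim.AdamW(model.parameters(), lr=2e-3, weight_decay=0.2, betas=(0.9, 0.95))

# Lion run from the numbers above: LR 2e-4, WD 2, betas (0.9, 0.99)
lion = Lion(model.parameters(), lr=2e-4, weight_decay=2.0, betas=(0.9, 0.99))
```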

@mitchellnw
Contributor Author

Plot for B/32


@rwightman
Collaborator

@lucidrains inconclusive so far, managed to almost match some recent AdamW results for a large fine-tune, but it took a fair bit of search. I feel that unless you're very resource constrained, AdamW will still be the go-to due to hparam familiarity...

@lucidrains
Contributor

So far at small scale (a short B/32 run, batch size 16k), well-tuned Lion slightly outperforms AdamW (still tuning AdamW).

AdamW (LR 2e-3, WD 0.2, betas=0.9, 0.95) = 42.1
Lion (LR 2e-4, WD 2, betas=0.9, 0.99) = 42.7

Trying to run analogous experiments for H but resources have been busy.

nice! maybe this is the perfect fit then, with the large batch size training needed for CLIP (in the paper they show a growing advantage of Lion over AdamW with increasing batch size, so this lines up)

@lucidrains inconclusive so far, managed to almost match some recent AdamW results for a large fine-tune, but it took a fair bit of search. I feel that unless you're very resource constrained, AdamW will still be the go-to due to hparam familiarity...

good to know! seems like everyone is reporting needing to fiddle around with hparams before seeing comparable results...

@lucidrains
Contributor

lucidrains commented Feb 21, 2023

@mitchellnw wanted to thank you for running and sharing this btw! honestly, i was on the fence about this technique, but now i believe it should be used in the large batch size regime

@xiangning-chen

xiangning-chen commented Feb 21, 2023

@mitchellnw Thanks for the experiments! I see that you used betas=0.9, 0.95 for AdamW instead of the default betas=0.9, 0.999, while for Lion it's still the default betas=0.9, 0.99. Could you please try betas=0.95, 0.98?

In our experiments, we use this setting when the beta2 in AdamW is 0.99, though in your case it's an even smaller 0.95.
One benefit of lowering beta2 in both AdamW and Lion is that, with a shorter memory of the gradient history, the learning rate can empirically be larger.

Really appreciate it!

@lucidrains
Contributor

@xiangning-chen 👋 just heard another positive result this morning from someone trustworthy! 💯

while you are here, have you figured out which learning rate scheduler is optimal with Lion? it seems like it would matter

@xiangning-chen

@lucidrains Thanks for sharing the good news!

We always used the same learning rate schedule as AdamW in our experiments including cosine decay, linear decay, and constant (all with 10K steps warmup). We also tried rsqrt decay when pre-training ViT on JFT, but the gain of Lion was lower compared to cosine decay, which we also observed in the ViT proxy task.

So I would say on ViT, the Lion optimizer is better suited for cosine decay, where the learning rate decays to either zero or a very small value, compared to rsqrt decay.
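(For readers who want the schedule in code: a minimal sketch of linear warmup followed by cosine decay to zero, written as a plain function of the step count. This is an illustration only, not the exact implementation used in the paper or in OpenCLIP.)

```python
import math

def lr_at_step(step: int, base_lr: float, warmup_steps: int = 10_000, total_steps: int = 100_000) -> float:
    """Linear warmup to base_lr, then cosine decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))
```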

@lucidrains
Contributor

@xiangning-chen thank you for your recommendation!

@mitchellnw
Contributor Author

Thanks @xiangning-chen, will try the other betas, and potentially a higher LR when making that change!

When raising LR, would you also raise WD?

@xiangning-chen

@mitchellnw Thank you! Actually I would decrease the WD when raising the LR to maintain the effective weight decay strength LR*WD.
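(Worked example of keeping lr * wd fixed, using the Lion numbers from the B/32 run above; the exact values for other runs may differ.)

```python
# Lion hparams from the B/32 run above: lr = 2e-4, wd = 2.0
base_lr, base_wd = 2e-4, 2.0
effective = base_lr * base_wd      # 4e-4

# Doubling the learning rate -> halve the weight decay to keep lr * wd unchanged
new_lr = 4e-4
new_wd = effective / new_lr        # 1.0
assert abs(new_lr * new_wd - effective) < 1e-12
```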

@mitchellnw
Contributor Author

Thanks, and congrats on the work by the way. Really cool results and quite an interesting optimizer you found!

@mitchellnw
Contributor Author

Ran short (20k iteration) runs with batch size 16k and H/14 on LAION-2B. There's not as much room for hparam tuning since the experiments are compute intensive, so I'm still finding Lion falling a bit short of AdamW. Please let me know which other hparams you'd recommend, @xiangning-chen.

Here's what I ran so far:
[figure: loss curves for the three runs listed below]

  1. Blue gets 57.6
  2. Orange gets 56.1
  3. Green gets 54.9

@xiangning-chen

xiangning-chen commented Feb 27, 2023

Thanks for the experiments!
Can I know the learning rate schedule and warmup iterations, @mitchellnw?
Also, the wd is still 0.2 for AdamW, right?
What about lr=4e-4, betas=(0.95, 0.98), wd=1.0 for Lion in order to maintain the same strength?
I feel like 20K iterations are pretty short, so that could make a difference (the red curve).

@lucidrains
Contributor

yea, i'm starting to hear more negative reports coming in unfortunately. the common story i hear is that it converges faster, but generalizes worse

@xiangning-chen

@lucidrains May I know in what domains they observe faster convergence but worse generalization? Thanks!
To me, this appears to be a consequence of using a small learning rate, which leads to faster convergence at the beginning but then gets stuck in local minima. The sign operation in Lion introduces a certain amount of noise, which usually leads to better generalization on our side.
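(For anyone following along, a minimal per-tensor sketch of the Lion update described in the paper: the applied step is the sign of an interpolation between the momentum and the current gradient, plus decoupled weight decay. This is illustrative, not an optimized implementation.)

```python
import torch

def lion_step(p: torch.Tensor, grad: torch.Tensor, m: torch.Tensor,
              lr: float = 1e-4, beta1: float = 0.9, beta2: float = 0.99, wd: float = 0.0):
    """One Lion update for a single parameter tensor (illustrative sketch)."""
    update = torch.sign(beta1 * m + (1 - beta1) * grad)  # sign of interpolated momentum
    p = p - lr * (update + wd * p)                       # step with decoupled weight decay
    m = beta2 * m + (1 - beta2) * grad                   # exponential moving average of gradients
    return p, m
```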

@mitchellnw
Contributor Author

@xiangning-chen thanks for the recommendations! Yes, 20k is extremely short, but sadly these experiments are already very expensive so I don't have much other option. Hmm, so you'd say just re-run red with LR 4e-4 instead of 5e-4? My guess would be that such a small change won't fix red, but you'd definitely know best here :)

@xiangning-chen

@mitchellnw May I know the warmup iterations and learning rate schedule? Thanks!

@mitchellnw
Contributor Author

mitchellnw commented Feb 27, 2023

Yep! 5k warmup iterations (linear warmup) then cosine decay. And weight decay for the AdamW baseline is 0.2.

@xiangning-chen

xiangning-chen commented Feb 28, 2023

@mitchellnw Thanks for the information!
I quickly tested training ViT-B/16 with 20k steps and batch size 4,096, and I will also try CLIP training. Here are the results and hyperparameters (I swept on a log3 scale; the betas are the defaults, (0.9, 0.999) for AdamW and (0.9, 0.99) for Lion):

  • Warmup 5k steps
    • For AdamW, lr=1e-3, wd=0.3, acc=73.0%
    • For Lion, lr=1e-4, wd=3.0, acc=76.18%
  • Warmup 10k steps
    • For AdamW, lr=1e-3, wd=0.3, acc=74.46%
    • For Lion, lr=3e-4, wd=1.0, acc=76.21%

So it seems like Lion works fine with few training steps; I will provide an update on the CLIP training results ASAP.

@mitchellnw
Contributor Author

Very interesting, thanks for sharing! A few comments/questions:

  • For AdamW, do you think performance could improve by, e.g., moving away from the default beta2 to 0.98 or 0.95?
  • For AdamW, do you use any form of gradient clipping, and if so, which?
  • Wow, that is incredible accuracy for only 20k steps at such a low batch size! Very nice, what dataset is this? The same one used in the paper? Must be very good quality!
  • Overall these look similar to the B/32 results from above where Lion performs better! Curious what you find if you try a model like H/14, where I'm currently trying to boost Lion accuracy. Appreciate your help in the hparam search, because the cluster I'm using is now extremely busy so experiments are limited.

@xiangning-chen

For AdamW, do you think performance could improve by, e.g., moving away from the default beta2 to 0.98 or 0.95?

With the same learning rate, I don't think this would make a big difference. But using beta2=0.95 here would definitely help with training stability, which means that a larger learning rate can be used without NaNs.

For AdamW, do you use any form of gradient clipping, and if so, which?

I used gradient clipping at 1.0 for both optimizers.
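(For reference, a minimal sketch of global-norm gradient clipping at 1.0 in a standard PyTorch training loop; the model and optimizer here are stand-ins, not the actual setup used above.)

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)  # stand-in model
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.3)

loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip global grad norm to 1.0
opt.step()
opt.zero_grad()
```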

Wow, that is incredible accuracy for only 20k steps at such a low batch size! Very nice, what dataset is this? The same one used in the paper? Must be very good quality!

Oh, this is not zero-shot accuracy; I just quickly tested on supervised image classification to see whether the short 20k training steps matter. Currently I'm having some permission issues with the internal image-text dataset. As soon as I regain access, I will proceed with the CLIP training.

Overall these look similar to the B/32 results from above where Lion performs better! Curious what you find if you try a model like H/14, where I'm currently trying to boost Lion accuracy. Appreciate your help in the hparam search, because the cluster I'm using is now extremely busy so experiments are limited.

Sure, glad to help with the hparam tuning and really hope that Lion can benefit the OpenCLIP project!
I will keep you updated about my progress and results.

Thank you so much!

@xiangning-chen

xiangning-chen commented Mar 7, 2023

@mitchellnw I have some updates to share! I discovered that the initial temperature value has an impact, and tuning it has resulted in better performance for Lion, compared to AdamW.

I conducted experiments using base-sized vision and text encoders, with a learning rate schedule of 10K steps for warmup and then cosine decay. The batch size was 16K, and the dataset I used was WebLI.

With initial temperature = 10.0, which is similar to the 1/0.07 used in OpenCLIP:

| Model | Steps | Optimizer | Lr   | Wd   | ImageNet Zero-shot Acc | Training Error | Val Error |
|-------|-------|-----------|------|------|------------------------|----------------|-----------|
| B/32  | 20K   | AdamW     | 1e-3 | 0.3  | 49.94                  | 1.70           | 1.69      |
| B/32  | 20K   | Lion      | 3e-4 | 1.0  | 49.47                  | 1.69           | 1.71      |
| B/16  | 20K   | AdamW     | 1e-3 | 0.3  | 56.96                  | 1.25           | 1.24      |
| B/16  | 20K   | Lion      | 3e-4 | 1.0  | 55.84                  | 1.26           | 1.27      |
| B/16  | 50K   | AdamW     | 1e-3 | 0.1  | 64.01                  | 0.80           | 0.80      |
| B/16  | 50K   | Lion      | 3e-4 | 0.33 | 63.74                  | 0.82           | 0.83      |

Lion indeed performed worse than AdamW for both ImageNet zero-shot accuracy and validation error.
However, with initial temperature = 30.0:

| Model | Steps | Optimizer | Lr   | Wd   | ImageNet Zero-shot Acc | Training Error | Val Error |
|-------|-------|-----------|------|------|------------------------|----------------|-----------|
| B/32  | 20K   | AdamW     | 1e-3 | 0.3  | 50.19                  | 1.70           | 1.68      |
| B/32  | 20K   | Lion      | 3e-4 | 1.0  | 50.75                  | 1.63           | 1.62      |
| B/16  | 20K   | AdamW     | 1e-3 | 0.3  | 57.23                  | 1.23           | 1.22      |
| B/16  | 20K   | Lion      | 3e-4 | 1.0  | 57.77                  | 1.17           | 1.18      |
| B/16  | 50K   | AdamW     | 1e-3 | 0.1  | 64.67                  | 0.79           | 0.80      |
| B/16  | 50K   | Lion      | 3e-4 | 0.33 | 65.04                  | 0.76           | 0.77      |

Lion becomes the clear winner. Note that initial temperature = 30.0 is also the better setting for both optimizers.
Learning curves are shown below for ViT-B/16 with 20K steps.

  • Training Error (figure)
  • Validation Error (figure)
  • ImageNet Zero-shot (figure)
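(For OpenCLIP readers: the learnable temperature is usually parameterized as a log-scale parameter, with exp(logit_scale) being the value reported above. A minimal sketch of the two initializations being compared, assuming that standard parameterization:)

```python
import numpy as np
import torch
import torch.nn as nn

# Standard CLIP/OpenCLIP init: exp(logit_scale) = 1/0.07 ≈ 14.3 (the "initial temperature ≈ 10" case above)
logit_scale_default = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

# The higher init suggested above: start the learnable scale at 30.0
logit_scale_high = nn.Parameter(torch.ones([]) * np.log(30.0))

print(logit_scale_default.exp().item(), logit_scale_high.exp().item())  # ~14.29, ~30.0
```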

@mitchellnw
Contributor Author

This is super interesting @xiangning-chen, thanks a lot for the exhaustive exploration! I would not have thought of modifying the temperature; how did you think of this?

I am really looking forward to trying this modification in our setting, looks very promising! However, it may now take some time as the cluster is extremely busy for the month of March so I'm not able to run jobs.

However, rest assured I will get to this, and if anyone else wants to try this out first for OpenCLIP on LAION-5b please do get in touch. Otherwise I will send the latest updates here when the cluster is back to normal in April :).

Thanks again!

@xiangning-chen

I would not have thought of modifying the temperature; how did you think of this?

I tracked the temperature value throughout training and found that Lion learns this value pretty slowly at the beginning, which is why I tried increasing the initial value.
[figure: learned temperature value over the course of training]

However, it may now take some time as the cluster is extremely busy for the month of March so I'm not able to run jobs.

No worries, thanks for trying it! Hope this can be helpful and also facilitate the usage of Lion.

@lucidrains
Contributor

lucidrains commented Mar 7, 2023

@xiangning-chen oh this is really interesting

what initial temperature value did the contrastive learning networks (LiT and BASIC) you tested on have?

@xiangning-chen

@lucidrains I used an initial temperature of 10 in the paper. But in LiT and BASIC, the vision tower is loaded from a pre-trained ckpt and is fixed during training, while in OpenCLIP both vision and language towers are initialized from scratch.

@xiangning-chen

@mitchellnw I further validated on the large- and giant-size CLIP, each with 20K steps (10K steps warmup then cosine decay) and initial temperature = 30.0. When training g/14, I switched to the Adafactor from the "Scaling Vision Transformers" paper due to memory issues.

| Model | #Params | Optimizer | Lr   | Wd  | ImageNet 0-shot Acc | Val Error | img/sec/core |
|-------|---------|-----------|------|-----|---------------------|-----------|--------------|
| L/16  | 653M    | AdamW     | 1e-3 | 0.3 | 62.61               | 0.94      | 198          |
| L/16  | 653M    | Lion      | 1e-4 | 3.0 | 63.13               | 0.92      | 227          |
| g/14  | 2.1B    | Adafactor | 1e-3 | 0.3 | 65.62               | 0.79      | 60           |
| g/14  | 2.1B    | Lion      | 1e-4 | 3.0 | 66.25               | 0.75      | 66           |

From the table, Lion is still the better one, and it also offers a runtime improvement of 10%-15%.

@mitchellnw
Contributor Author

Great! ETA for continuing to test in OpenCLIP is late March or early April, looking forward!

@lucidrains
Contributor

lucidrains commented Mar 13, 2023

i am readying an 8-bit version of Lion over at the bitsandbytes repository, and hope to get it merged sometime toward the end of this month
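(Not merged at the time of this comment; a hypothetical usage sketch, assuming the 8-bit optimizer eventually lands as bnb.optim.Lion8bit with the usual torch-style constructor:)

```python
import torch.nn as nn
import bitsandbytes as bnb  # assumes a release that includes the 8-bit Lion

model = nn.Linear(512, 512)  # stand-in model
# hypothetical class name and hparams, mirroring the Lion settings discussed above
opt = bnb.optim.Lion8bit(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=1.0)
```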

@xiangning-chen

@mitchellnw Just wondering, do you have any update on using Lion with a higher initial temperature? Thanks!

@mitchellnw
Contributor Author

Thanks for the reminder! I was expectedly behind schedule but just launched the jobs!

@mitchellnw
Contributor Author

mitchellnw commented Apr 15, 2023

[figure: loss curves for the runs listed below]

That did help narrow the gap a lot! AdamW and Lion are now within 1 percentage point of each other. Do you think I should go even higher on the init temp? Feel free to let me know exactly what hparam setting you'd try next and I'll run it.

  • AdamW* (2e-3, 0.9, 0.98), wd 0.2 : 57.57
  • Lion (2e-4, 0.95, 0.98), wd 2 : 56.05
  • AdamW* (2e-3, 0.9, 0.98), wd 0.2, init temp 30 : 57.33
  • Lion (2e-4, 0.95, 0.98), wd 2, init temp 30 : 56.98

* I should also mention: the AdamW I'm using here uses "update clipping" from AdaFactor to get rid of loss spikes. To use vanilla AdamW I need to decrease beta2 a lot for stability. The best vanilla AdamW result is AdamW (2e-3, 0.9, 0.9), wd 0.2: 57.29, so within 0.3pp of Lion.
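(For context, a minimal sketch of the Adafactor-style "update clipping" referred to above: the per-tensor update is rescaled so that its root-mean-square does not exceed a threshold, here 1.0. This follows the rule from the Adafactor paper and is not necessarily the exact variant used in these runs.)

```python
import torch

def clip_update_rms(update: torch.Tensor, threshold: float = 1.0) -> torch.Tensor:
    """Rescale an update so its RMS is at most `threshold` (Adafactor-style update clipping)."""
    rms = update.pow(2).mean().sqrt()
    return update / torch.clamp(rms / threshold, min=1.0)
```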

mitchellnw marked this pull request as draft April 16, 2023 17:48
@xiangning-chen

Thanks for the result! I used gradient clipping in my experiments but not update clipping.
It seems like the stability issue is a big concern, as both betas in your vanilla AdamW setting are 0.9.
Do you think the update clipping plays a big role here?

@mitchellnw
Contributor Author

I do find update clipping helps compared to gradient clipping, but even without update clipping and without gradient clipping the result is still fairly good (57.33). But yes, as you point out, the beta2 is very low here (0.9) for stability reasons. With Lion we don't observe any stability problems though.

@mitchellnw
Contributor Author

Not strictly on topic, but I thought you would appreciate hearing this @xiangning-chen: I'm testing out the Lion optimizer for a text task and finding really nice performance there.

@xiangning-chen

@mitchellnw Thanks for letting me know, that's really good news!
