Automatically resumable training in learner class #4020
## TLDR

This PR introduces granular resumable training in the `Learner` class. Specifically, the Learner can resume training from the exact epoch and iteration at which a checkpoint was saved. The checkpoint file saves this info (along with the model and optimizer states). This way, when a user invokes `Learner.load`, the Learner automatically resumes training from the last saved `n_epoch` and `n_iter`.

## In Summary
- `Learner.save` saves the iteration info along with the `model` and `opt` states: specifically, the epoch (`Learner.epoch`) and the iteration within that epoch (`Learner.iter`). Similarly, `Learner.load` checks for a saved `epoch` and `iter`.
- Once `Learner.load` has been invoked, `Learner.fit` retrieves which `epoch` and `iter` to resume training on.
- `SkipToEpoch` has been modified to `SkipToIter`, which skips training up to the `iter`-th iteration of the `epoch`-th epoch.
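A usage sketch of the intended flow (the learner setup and checkpoint name here are illustrative, not from the diff; `with_iter` is the new argument proposed in this PR):

```python
from fastai.vision.all import *

path = untar_data(URLs.MNIST_SAMPLE)
dls = ImageDataLoaders.from_folder(path)
learn = vision_learner(dls, resnet18, metrics=accuracy)

learn.fit(5)                        # train for a while
learn.save('ckpt', with_iter=True)  # checkpoint records epoch/iter with model+opt

# Later, e.g. after a hardware failure and restart:
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.load('ckpt')                  # restores model/opt and the saved epoch/iter
learn.fit(10)                       # skips ahead to the saved position, then trains
```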
## Problem Statement
I think the Learner class is designed for small-scale training on a local GPU, and does not support large-scale training without a lot of custom housekeeping code. (FYI, I am currently training an LLM using fastai.)
One big problem I faced was that I could not resume training if my hardware suddenly failed. Sure, there's the `start_epoch` argument in `Learner.fit`. But for large-scale training (e.g. for LLMs), a single epoch is EXTREMELY long, and it makes more sense to be able to resume from a specific iteration within a specific epoch.

Ideally, the Learner should automatically resume training from the last saved epoch and iteration. Therefore the checkpoint file should save the current epoch and iter info along with the model and opt states, and `Learner.fit` should automatically resume from this iteration, UNLESS the user explicitly specifies `start_epoch` (and by extension, `start_iter`) in `Learner.fit`.
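A minimal sketch of that override behaviour (checkpoint name illustrative; `start_iter` is the new argument proposed here):

```python
learn.load('ckpt')            # stores the checkpoint's epoch/iter on the learner
learn.fit(10)                 # resumes from the saved epoch/iter
learn.fit(10, start_epoch=0)  # explicit start_epoch overrides the saved position
```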
## Code

Note: Please suggest improvements to the code structure to make it cleaner and style-compliant. I've flagged obvious issues that may be problematic. Also, let me know which unit tests, Jupyter notebook experiments, and documentation need to be written; I'll be happy to spend more time on it.
- The `save` patch function passes an additional dictionary containing the current `epoch` and `iter` to the `save_model` function, IF `with_iter` is True. If `with_iter` is True, `with_opt` must also be True.
- `save_model` saves a dictionary (using `torch.save`) containing `model`, `opt`, and `iter` (where `iter` is the dictionary containing the epoch and iter info).
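In code, roughly (a sketch based on fastai's existing `save_model`/`Learner.save`; the actual diff may differ in details):

```python
import torch
from fastcore.all import patch, delegates, join_path_file
from fastai.learner import Learner
from fastai.torch_core import get_model

def save_model(file, model, opt, with_opt=True, with_iter=False, iter_info=None,
               pickle_protocol=2):
    "Save `model` (and optionally `opt` state and epoch/iter info) to `file`."
    if opt is None: with_opt = False
    if with_iter and not with_opt:
        raise ValueError("with_iter=True requires with_opt=True")
    state = get_model(model).state_dict()
    if with_opt: state = {'model': state, 'opt': opt.state_dict()}
    if with_iter: state['iter'] = iter_info   # e.g. {'epoch': 3, 'iter': 1200}
    torch.save(state, file, pickle_protocol=pickle_protocol)

@patch
@delegates(save_model)
def save(self:Learner, file, with_iter=False, **kwargs):
    "Save model (and opt) state, plus the current epoch/iter if `with_iter`."
    file = join_path_file(file, self.path/self.model_dir, ext='.pth')
    iter_info = {'epoch': self.epoch, 'iter': self.iter} if with_iter else None
    save_model(file, self.model, self.opt, with_iter=with_iter,
               iter_info=iter_info, **kwargs)
    return file
```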
- The `load` patch function stores the iter info (a dictionary) in a new Learner attribute, `Learner.resumeIter`.
- The `load_model` function not only loads the model and opt states, but also returns the iter info to the above `load` function. NOTE: I'm not sure how to modify `Learner.epoch` and `Learner.iter` by reference, so I return the values instead. Please suggest a better way.
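A sketch of the load path (based on fastai's existing `load_model`/`Learner.load`; the `resumeIter` attribute name follows this PR's description):

```python
import torch
from fastcore.all import patch, join_path_file
from fastai.learner import Learner
from fastai.torch_core import get_model

def load_model(file, model, opt, with_opt=True, device=None, strict=True):
    "Load model (and opt) state from `file`; return any saved epoch/iter info."
    state = torch.load(file, map_location=device)
    hasopt = isinstance(state, dict) and 'opt' in state
    model_state = state['model'] if hasopt else state
    get_model(model).load_state_dict(model_state, strict=strict)
    if hasopt and with_opt: opt.load_state_dict(state['opt'])
    # `Learner.epoch`/`Learner.iter` are plain ints, so they can't be updated
    # by reference here; the info is returned for the caller to store instead.
    return state.get('iter') if isinstance(state, dict) else None

@patch
def load(self:Learner, file, device=None, with_opt=True, strict=True):
    "Load model/opt state; remember saved epoch/iter info in `self.resumeIter`."
    if device is None and hasattr(self.dls, 'device'): device = self.dls.device
    if self.opt is None: self.create_opt()
    file = join_path_file(file, self.path/self.model_dir, ext='.pth')
    self.resumeIter = load_model(file, self.model, self.opt,
                                 with_opt=with_opt, device=device, strict=strict)
    return self
```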
- `Learner.load` checks for the `self.resumeIter` variable and initializes the `SkipToIter` (modified `SkipToEpoch`) callback. However, `start_epoch` and a new argument, `start_iter`, override the loaded epoch and iter values. `SkipToIter` essentially adds a `before_batch` method alongside the `before_epoch` method, which ensures that training is skipped until the desired iteration (see the callback sketch below).
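A sketch of the callback (fastai's current `SkipToEpoch` has only the `before_epoch` part; the `before_batch` handling is the addition described in this PR):

```python
from fastai.callback.core import Callback, CancelEpochException, CancelBatchException

class SkipToIter(Callback):
    "Skip training up to iteration `iter` of epoch `epoch`"
    order = 70
    def __init__(self, epoch:int, iter:int=0):
        self._skip_to_epoch,self._skip_to_iter = epoch,iter
    def before_epoch(self):
        # Cancel whole epochs before the target epoch
        if self.epoch < self._skip_to_epoch: raise CancelEpochException
    def before_batch(self):
        # Within the target epoch, cancel batches before the target iteration
        if self.epoch == self._skip_to_epoch and self.iter < self._skip_to_iter:
            raise CancelBatchException
```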