Saving/loading model with a checkpoint callback #2367
Hi @tRosenflanz,

Not sure why the following is not working for you:

model = DLinearModel(input_chunk_length=4, output_chunk_length=1, save_checkpoints=True, work_dir="../custom_dir")

As for the monitored metrics, based on the PyTorch Lightning documentation, it should be possible to use functions other than the loss to identify the best checkpoint, since all the torch_metrics entries are logged:

from darts.models import DLinearModel
from torchmetrics import ExplainedVariance
from pytorch_lightning.callbacks import ModelCheckpoint
# Darts prepends "val_" and "train_" to the names of the torch_metrics entries
checkpoint_callback = ModelCheckpoint(monitor='val_ExplainedVariance')
# Add the callback
model = DLinearModel(
input_chunk_length=4,
output_chunk_length=1,
save_checkpoints=True,
torch_metrics=ExplainedVariance(),
pl_trainer_kwargs={"callbacks":[checkpoint_callback]}
)
model.fit(train, val_series=val, epochs=10)
new_model = DLinearModel.load_from_checkpoint(model.model_name, best=True)

Please note that if a custom work_dir was specified when creating the model, it also needs to be passed to load_from_checkpoint. Manually saving the model as you described is probably the most practical way to copy the checkpoints to the desired place. Let me know if it helps.
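For completeness, a minimal sketch of the manual saving route mentioned above; the target path `../custom_dir/dlinear_best.pt` is only an illustrative example:

```python
# Save the trained model (wrapper + underlying checkpoint) to an explicit location
model.save("../custom_dir/dlinear_best.pt")

# Restore it later from that same path
restored_model = DLinearModel.load("../custom_dir/dlinear_best.pt")

# If a custom work_dir was used at creation time, pass it again when
# loading from the automatic checkpoints
best_model = DLinearModel.load_from_checkpoint(
    model.model_name, work_dir="../custom_dir", best=True
)
```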
I think the automatic concatenation of different subpaths into work_dir and its interaction with custom training confused me a bit. I am switching from pytorch_forecasting, where I had to define my own trainer, so the Darts way of handling it was unfamiliar at first. I see no issue with the method you provide.
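For anyone else tripped up by the same thing, here is roughly how the subpaths get concatenated. This is a sketch of my understanding; the folder names ("darts_logs", "checkpoints") and the model name are illustrative assumptions rather than guaranteed constants:

```python
import os

# work_dir defaults to something like os.path.join(os.getcwd(), "darts_logs");
# with a custom work_dir the checkpoints still end up nested under
# <work_dir>/<model_name>/checkpoints/
work_dir = "../custom_dir"
model_name = "my_dlinear"  # hypothetical; Darts auto-generates a name if none is given
checkpoint_dir = os.path.join(work_dir, model_name, "checkpoints")
print(checkpoint_dir)  # ../custom_dir/my_dlinear/checkpoints
```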
Is your feature request related to a current problem? Please describe.
The current torch model checkpointing logic is quite rigid. It only allows tracking the loss rather than other metrics, and it uses somewhat hardcoded directories, which makes restoring the best model for a given metric challenging.
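To illustrate the current behaviour (a sketch of my understanding; `train`/`val` are placeholder series, and the monitored quantity and defaults may differ between versions):

```python
from darts.models import DLinearModel

# save_checkpoints=True enables the built-in checkpointing, which tracks the
# validation loss only and writes to <work_dir>/<model_name>/checkpoints/
model = DLinearModel(
    input_chunk_length=4,
    output_chunk_length=1,
    save_checkpoints=True,
)
model.fit(train, val_series=val, epochs=10)

# Restores the checkpoint with the best (lowest) validation loss
best = DLinearModel.load_from_checkpoint(model.model_name, best=True)
```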
Describe proposed solution
Either, or both, of the following:
Describe potential alternatives
I worked around this issue by abusing the os.path.join behavior with absolute paths (which must start with /): "If a segment is an absolute path (which on Windows requires both a drive and a root), then all previous segments are ignored and joining continues from the absolute path segment."
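A minimal demonstration of the os.path.join behaviour being relied on (the paths are illustrative):

```python
import os

# Normal case: relative segments are simply concatenated
print(os.path.join("darts_logs", "my_model", "checkpoints"))
# -> darts_logs/my_model/checkpoints

# Absolute segment: everything before it is discarded, which is what allows
# an absolute path to override the directory logic built up internally
print(os.path.join("darts_logs", "/abs/path/to/checkpoints"))
# -> /abs/path/to/checkpoints
```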
Additional context
Clearly the solution above is not ideal, since it abuses somewhat hidden behavior of os.path.join and relies on knowledge of the method internals. Perhaps it is the user's responsibility to set the proper parameters when a custom trainer is used, but load_from_checkpoint could accept explicit paths for the model and checkpoint and skip the current directory/naming logic when absolute paths are given.
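As a sketch of what this could look like; this is purely hypothetical, and the `model_path`/`checkpoint_path` parameters below are invented for illustration and are not part of the current Darts API:

```python
# Hypothetical extension (NOT the current Darts API): skip the
# work_dir/model_name lookup entirely when explicit paths are given
new_model = DLinearModel.load_from_checkpoint(
    model_path="/abs/path/to/model/_model.pth.tar",        # invented parameter
    checkpoint_path="/abs/path/to/checkpoints/best.ckpt",  # invented parameter
)
```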