
Saving/loading model with a checkpoint callback #2367

Closed
tRosenflanz opened this issue May 3, 2024 · 2 comments
Labels
q&a Frequent question & answer

Comments

@tRosenflanz
Contributor

tRosenflanz commented May 3, 2024

Is your feature request related to a current problem? Please describe.
The current torch model checkpointing logic is quite rigid. It only allows tracking the loss rather than other metrics, and it uses somewhat hard-coded directories, which makes restoring the best model for a given metric challenging.

Describe proposed solution
Either or both of:

  1. Allow overriding the monitored metric/arguments of the checkpointing callback when save_checkpoints is enabled (a hypothetical sketch is shown below).
  2. Make load_from_checkpoint more flexible.
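
Purely for illustration, a hypothetical sketch of what option 1 could look like; the checkpoint_callback_kwargs argument below does not exist in Darts today:

from darts.models import TSMixerModel

model = TSMixerModel(
    input_chunk_length=24,
    output_chunk_length=12,
    save_checkpoints=True,
    # hypothetical: forwarded to the internally created ModelCheckpoint
    checkpoint_callback_kwargs={"monitor": "val_MeanAbsoluteError", "mode": "min"},
)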

Describe potential alternatives
I worked around this issue by abusing the os.path.join behavior with absolute paths (they must start with /). From the Python documentation:
If a segment is an absolute path (which on Windows requires both a drive and a root), then all previous segments are ignored and joining continues from the absolute path segment.
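
For illustration, a quick demonstration of that behavior (paths are made up):

import os

# A later absolute segment discards everything joined before it:
os.path.join("darts_logs/my_model/checkpoints", "/abs/path/best.ckpt")
# -> '/abs/path/best.ckpt'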

# Load the best checkpoint tracked by a custom ModelCheckpoint callback.
# `m` is the already fitted model, `checkpoint_callback` is the ModelCheckpoint instance.
import os

cwd = os.getcwd()
checkp_dir = os.path.dirname(checkpoint_callback.best_model_path)

# Save a copy of the model object into the checkpoint directory under the
# internal name Darts expects (_model.pth.tar)
m.save(os.path.join(checkp_dir, "_model.pth.tar"))

tft = TSMixerModel.load_from_checkpoint(
    "feel free to put anything here because it actually doesn't matter",
    work_dir=os.path.join(cwd, checkp_dir),  # absolute path, ignores internal path prefixes
    file_name=os.path.join(cwd, checkpoint_callback.best_model_path),  # absolute path, ignores internal path prefixes
)

# Copying to the final destination can then be done with a plain save
name = "darts_model"
destination = os.path.join(OUTPUT_DIR, name)
tft.save(destination)

Additional context
Clearly the solution above is not ideal, since it abuses somewhat hidden behavior of os.path.join and relies on knowledge of the method internals. Perhaps it is the user's responsibility to set the proper parameters when a custom trainer is used, but load_from_checkpoint could accept explicit paths for the model and the checkpoint, and skip the current directory/naming logic when absolute paths are given.

@tRosenflanz tRosenflanz added the triage Issue waiting for triaging label May 3, 2024
@madtoinou
Collaborator

madtoinou commented May 7, 2024

Hi @tRosenflanz,

Not sure why the work_dir argument is not working for you; it should let you specify a custom path for the automatic checkpoints when creating the torch model:

model = DLinearModel(input_chunk_length=4, output_chunk_length=1, save_checkpoints=True, work_dir="../custom_dir")
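
For completeness, a minimal sketch of reloading the best automatic checkpoint from that custom directory afterwards (assuming the model above was fitted with save_checkpoints=True):

# Reload the best checkpoint from the custom work_dir
model = DLinearModel.load_from_checkpoint(
    model_name=model.model_name,
    work_dir="../custom_dir",
    best=True,
)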

As for the monitored metric: based on the PyTorch Lightning documentation, it should be possible to use metrics other than the loss to identify the best checkpoint, since all the torch_metrics are logged to the trainer:

from darts.models import DLinearModel
from torchmetrics import ExplainedVariance
from pytorch_lightning.callbacks import ModelCheckpoint

# Darts prepends "val_" and "train_" to the torch_metrics entry names
checkpoint_callback = ModelCheckpoint(monitor='val_ExplainedVariance')

# Add the callback
model = DLinearModel(
    input_chunk_length=4,
    output_chunk_length=1,
    save_checkpoints=True,
    torch_metrics=ExplainedVariance(), 
    pl_trainer_kwargs={"callbacks":[checkpoint_callback]}
)

model.fit(train, val_series=val, epochs=10)
new_model = DLinearModel.load_from_checkpoint(model.model_name, best=True)

Please note that specifying the dirpath and filename arguments of ModelCheckpoint will export the .ckpt to the desired path, but the .pth.tar will not be exported there, making that checkpoint unusable unless the .pth.tar is copied to the expected relative path (we need to investigate whether this can be automated when such a callback is provided). This makes the work_dir argument the ideal way of changing the checkpoints path.

Manually saving the model as you described is probably the most practical way to copy the checkpoints to the desired place.
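
A minimal sketch of that manual copy, assuming the same model as above and a made-up output path:

# Reload the best automatic checkpoint, then save it wherever you like
best_model = DLinearModel.load_from_checkpoint(model.model_name, best=True)
best_model.save("/some/output/dir/darts_model.pt")

# Later, restore it directly from that location
restored = DLinearModel.load("/some/output/dir/darts_model.pt")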

Let me know if it helps.

@tRosenflanz
Contributor Author

I think the automatic concatenation of different sub-paths onto work_dir and its interaction with a custom trainer confused me a bit. I am switching from pytorch_forecasting, where I had to define my own trainer, so the Darts way of handling it was unfamiliar at first. I see no issue with the method you provided.

@madtoinou madtoinou added q&a Frequent question & answer and removed triage Issue waiting for triaging labels May 17, 2024