Activating (deactivating) callbacks at specific epochs or milestones and SequentialLR #1049

Open
AdamCoxson opened this issue Feb 26, 2024 · 1 comment

AdamCoxson commented Feb 26, 2024

Hi,

I'm using ReduceLROnPlateau and EarlyStopping with my own custom Concordance Correlation monitor metric (RhoC). I want my networks to train with .fit with the callbacks inactive for the first 50 or 100 epochs, and only active after that. As far as I'm aware, there are no arguments to do this, so I created my own version of the lr_scheduler module and modified the LRScheduler(Callback) class to take an argument called epoch_start.

I've shown the modified class functions (__init__, kwargs, and on_epoch_end) at the bottom.

This does the job for me for now, and I can do a similar thing to modify EarlyStopping. I just wanted to check whether there is actually a way to do this, or a workaround, without needing to modify the source code. I also looked into SequentialLR and could apply it in a similar way as in this post, just with ConstantLR for the first 50 epochs. That would work in plain PyTorch, or if I were manually coding my own fit function, but I'm unsure how to integrate SequentialLR with skorch's fit and callbacks system.
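For illustration, this is roughly what I have in mind in plain PyTorch (StepLR here is just a stand-in for whatever scheduler would follow the constant phase):

import torch
from torch.optim.lr_scheduler import ConstantLR, StepLR, SequentialLR

model = torch.nn.Linear(10, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# hold the lr fixed (factor=1.0) for the first 50 epochs, then decay it
scheduler = SequentialLR(
    opt,
    schedulers=[
        ConstantLR(opt, factor=1.0, total_iters=50),
        StepLR(opt, step_size=10, gamma=0.5),
    ],
    milestones=[50],
)

for epoch in range(100):
    ...  # forward/backward pass and opt.step() would go here
    scheduler.step()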

So my questions are:

  1. Other than modifying the source code, how can I add an activation delay to callbacks based on the epoch number? Is there a way to do this that already exists with skorch's .fit function?
  2. How could I implement SequentialLR or an equivalent set of learning rate schedulers in callbacks?

This is more out of interest, as modifying the source code works for me. If you have any pointers, let me know :)

    def __init__(self,
                 policy='WarmRestartLR',
                 monitor='train_loss',
                 event_name="event_lr",
                 step_every='epoch',
                 epoch_start=1,
                 **kwargs):
        self.policy = policy
        self.monitor = monitor
        self.event_name = event_name
        self.step_every = step_every
        self.epoch_start = epoch_start
        vars(self).update(kwargs)

    @property
    def kwargs(self):
        # These are the parameters that are passed to the
        # scheduler. Parameters that don't belong there must be
        # excluded, so epoch_start is added to the excluded tuple.
        excluded = ('policy', 'monitor', 'event_name', 'step_every', 'epoch_start')
        kwargs = {key: val for key, val in vars(self).items()
                  if not (key in excluded or key.endswith('_'))}
        return kwargs

    def on_epoch_end(self, net, **kwargs):
        if self.step_every != 'epoch':
            return
        if isinstance(self.lr_scheduler_, ReduceLROnPlateau):
            if callable(self.monitor):
                score = self.monitor(net)
            else:
                try:
                    score = net.history[-1, self.monitor]
                except KeyError as e:
                    raise ValueError(
                        f"'{self.monitor}' was not found in history. A "
                        f"Scoring callback with name='{self.monitor}' "
                        "should be placed before the LRScheduler callback"
                    ) from e
            n_epoch = len(net.history)
            if n_epoch <= self.epoch_start:
                print("Not starting lr scheduler yet")
                return
            self._step(net, self.lr_scheduler_, score=score)
            # ReduceLROnPlateau does not expose the current lr so it can't be recorded
        else:
            if (
                    (self.event_name is not None)
                    and hasattr(self.lr_scheduler_, "get_last_lr")
            ):
                net.history.record(self.event_name, self.lr_scheduler_.get_last_lr()[0])
            self._step(net, self.lr_scheduler_)

My callbacks are defined like this:

r_score = EpochScoring(scoring=rhoc_score, lower_is_better=False, name='valid_rhoc')
r_early_stop = EarlyStopping(monitor='valid_rhoc', patience=30, lower_is_better=False,
                             threshold_mode='abs', threshold=1e-4, load_best=True)
lr_plateau = LRScheduler(epoch_start=50, monitor='valid_loss', policy='ReduceLROnPlateau',
                         factor=0.25, patience=10, threshold=1e-3, verbose=True)
# scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
callbacks = [r_score, r_early_stop, lr_plateau]
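For context, these callbacks are passed to the net in the usual way (MyModule, X_train, and y_train are placeholders for my actual model and data):

from skorch import NeuralNetRegressor

net = NeuralNetRegressor(
    MyModule,
    max_epochs=300,
    callbacks=callbacks,  # [r_score, r_early_stop, lr_plateau]
)
net.fit(X_train, y_train)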
BenjaminBossan (Collaborator) commented

In general, I think your approach is valid, and making such modifications to skorch classes is a perfectly reasonable way to go.

I wonder if the on_epoch_end method could not be simplified by using an early return like this:

    def on_epoch_end(self, net, **kwargs):
        n_epoch = len(net.history)
        if n_epoch <= self.epoch_start:
            print("Not starting lr scheduler yet")
            return
        return super().on_epoch_end(net, **kwargs)
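
Spelled out as a standalone subclass (just a sketch, untested), it could look something like the following. Note that the kwargs property also needs adjusting so that epoch_start is not forwarded to the wrapped policy:

from skorch.callbacks import LRScheduler

class DelayedLRScheduler(LRScheduler):
    """An LRScheduler that stays dormant for the first epoch_start epochs."""

    def __init__(self, policy='WarmRestartLR', epoch_start=1, **kwargs):
        super().__init__(policy=policy, **kwargs)
        self.epoch_start = epoch_start

    @property
    def kwargs(self):
        # drop epoch_start so it is not passed to the underlying scheduler
        kw = super().kwargs
        kw.pop('epoch_start', None)
        return kw

    def on_epoch_end(self, net, **kwargs):
        if len(net.history) <= self.epoch_start:
            return  # scheduler not active yet
        return super().on_epoch_end(net, **kwargs)

# drop-in replacement for the lr_plateau callback above
lr_plateau = DelayedLRScheduler(epoch_start=50, monitor='valid_loss',
                                policy='ReduceLROnPlateau', factor=0.25,
                                patience=10, threshold=1e-3)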

Regarding SequentialLR, I haven't checked in detail, but wouldn't something like this work?

lr_scheduler_seq = LRScheduler(policy=SequentialLR, ...)  # add more args to SequentialLR here

Check out the LRScheduler docs for the details.
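
One caveat: SequentialLR expects already-instantiated schedulers that are bound to the optimizer, and skorch only creates the optimizer during fit. A possible way around that (an untested sketch; ConstantThenStep and its defaults are made up for illustration) is a thin subclass that builds its child schedulers itself, so skorch can instantiate it like any other policy:

from torch.optim.lr_scheduler import ConstantLR, StepLR, SequentialLR
from skorch.callbacks import LRScheduler

class ConstantThenStep(SequentialLR):
    """Keep the lr constant for `milestone` epochs, then hand over to StepLR."""

    def __init__(self, optimizer, milestone=50, step_size=10, gamma=0.5,
                 last_epoch=-1):
        schedulers = [
            ConstantLR(optimizer, factor=1.0, total_iters=milestone),
            StepLR(optimizer, step_size=step_size, gamma=gamma),
        ]
        super().__init__(optimizer, schedulers=schedulers,
                         milestones=[milestone], last_epoch=last_epoch)

lr_scheduler_seq = LRScheduler(policy=ConstantThenStep, milestone=50)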
