
How to mask or ignore target features during model training #2381

Open
yunakkano opened this issue May 12, 2024 · 13 comments
Labels
question Further information is requested

Comments

yunakkano commented May 12, 2024

I am having difficulty training a TSMixer model in my current project, where target values won't be available at the time of prediction.

Is there any option to mask or ignore past target values during model training?

@yunakkano changed the title from "How to mask or ignore target features during model training in deeplearning models" to "How to mask or ignore target features during model training" on May 12, 2024
@madtoinou added the "question (Further information is requested)" label on May 13, 2024
@madtoinou (Collaborator)

Hi @yunakkano,

If you have a component that you're not interested in forecasting, you have two options:

  • use it as a covariate instead: from what you described, probably as past covariates, since new values won't be available at inference time (see the sketch after this list).
  • remove it entirely from the training series and don't use it for training/prediction.
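
For the first option, a minimal sketch of the split, assuming a multivariate series with the components y, x1, x2, x3 from this thread (`TimeSeries.from_dataframe` and component selection via `series[...]` are standard Darts API; `df` is a placeholder):

import pandas as pd
from darts import TimeSeries
from darts.models import TSMixerModel

# assuming `df` is a DataFrame with a datetime index and columns y, x1, x2, x3
series = TimeSeries.from_dataframe(df)

# forecast only `y`, and feed the remaining components as past covariates
target = series["y"]
past_cov = series[["x1", "x2", "x3"]]

model = TSMixerModel(input_chunk_length=6, output_chunk_length=6)
model.fit(target, past_covariates=past_cov)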

@dennisbader (Collaborator)

Can you elaborate on what you mean by "the target values are not available at prediction time"?

Just for clarification:

  • the target variables are the ones that you want to forecast (as opposed to covariates / external features that you do not want to forecast)
  • for training our torch models (neural networks), you need at least input_chunk_length + output_chunk_length + output_chunk_shift points per target series (for both a single series and multiple series / a list of series). There are two options for series that are too short:
    • If you work with multiple series, you could just drop the too-short ones from the input series.
    • For some series (e.g. sales data), prepending some values before the start of the series can help. Let's say you have sales data for a product, but the product was launched recently and the sales history is not long enough yet. Here you could prepend zeros (zero sales) until the length covers the required model input time frame. Additionally, you could add a new (binary) covariate feature that flags these time periods (e.g. value 1 if the time is before the launch and 0 otherwise); see the sketch after this list.
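
A minimal sketch of the zero-prepending idea, with hypothetical sales numbers, dates, and lengths (`TimeSeries.from_times_and_values` and `append` are standard Darts API):

import numpy as np
import pandas as pd
from darts import TimeSeries

# hypothetical short sales history: product launched 4 months ago
sales = TimeSeries.from_times_and_values(
    pd.date_range("2024-01-01", periods=4, freq="MS"),
    np.array([3.0, 5.0, 8.0, 13.0]),
)

required_length = 12  # e.g. input_chunk_length + output_chunk_length
n_pad = required_length - len(sales)

# prepend zeros (zero sales before the launch)
pad_times = pd.date_range(end=sales.start_time() - sales.freq, periods=n_pad, freq="MS")
padded_sales = TimeSeries.from_times_and_values(pad_times, np.zeros(n_pad)).append(sales)

# binary covariate flagging the pre-launch period (1 before launch, 0 after)
launch_flag = TimeSeries.from_times_and_values(
    padded_sales.time_index,
    (padded_sales.time_index < sales.start_time()).astype(np.float64),
)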

yunakkano (Author) commented May 14, 2024

Hi @madtoinou and @dennisbader ,

Thank you both for your helpful responses. I appreciate your insights and suggestions.

To clarify my scenario, I have a time series dataset with feature columns x1, x2, x3, and y. The variable y is the target I want to predict, but it is only available after x1, x2, and x3 are observed. This means that past values of y cannot be used as input when training the prediction model.

Assuming a time horizon t_0, t_1, ..., t_now, t_now+1, and so on, we are currently at t_now. I want to train the model with the features x1, x2, and x3 from the time section t_0 to t_now in order to predict y values from t_now+1 onwards.

To address your question, @dennisbader , regarding what I mean by "the target values are not available at prediction time," here is some pseudo code to illustrate my situation. I aim to build a TSMixerModel that takes 6 past time steps to predict 6 future time steps, but I'm unsure how to set the target variable. My understanding is that if I include y in the target, the model will be trained with y values from the past time steps (i.e. before t_now), which is not applicable in my scenario. On the other hand, if I exclude y from the target, the model will not be trained to predict y values.

(If my understanding is wrong, correct me please.)

Here is the pseudo code for reference:

model = TSMixerModel(
    input_chunk_length=6,
    output_chunk_length=6
)

model.fit(target, past_covariates=past_cov)

Given this context, could you provide more specific advice on how to handle this scenario within the Darts framework? Is there a way to mask or ignore the y values from the past time steps during model training while still training the model to predict y values from t_now+1 onwards?

Thank you again for your assistance.

Best regards,

@dennisbader (Collaborator)

Hi @yunakkano let me see if I understand your problem correctly. If you were to predict in this moment, your scenario would look something like below?

var\time   t0   t1   t2   t3   t4   tnow   tpred1   tpred...   tpred6
y           1    2    3    4    5    -      ?        ?          ?
x1          5    6    7    8    9    10     -        -          -
  • y is your target variable
  • x1 is one of your covariate features
  • tnow is the current time step, after which you want to predict tpred1 until tpred6 (horizon = output_chunk_length = 6 points)
  • t0 until tnow is your input window (input_chunk_length=6 points)
  • ? are the y values you want to predict
  • - means the variable is not available at the given time step

There are two ways to do this.

  • with our neural networks (such as TSMixer), it can be done, but you will not be able to use the tnow values of x1, x2, .... For this you can use output_chunk_shift=1 at model creation. This will use t0 until t4 (in other words tnow-1) from y and x* to predict tpred1 (in other words tnow+1).

    model = TSMixerModel(
        input_chunk_length=6,
        output_chunk_length=6,
        output_chunk_shift=1,
    )
    model.fit(target, past_covariates=past_cov)
    
  • with our regression models (any of them), it can be done using all available information. The configuration below will use all the information from the table. In this case we need to use the covariates as future_covariates. When predicting, make sure that your target series ends at tnow-1 and the covariates series at tnow (see the prediction sketch after this list).

    model = RegressionModel(
        lags=5,
        lags_future_covariates=[-6, -5, -4, -3, -2, -1],
        output_chunk_length=6,
        output_chunk_shift=1,
    )
    model.fit(target, future_covariates=past_cov)
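
    A sketch of the prediction call under those alignment assumptions (`target` and `past_cov` as above; hypothetical, just to show where each series should end):

    # `target` ends at tnow-1; `past_cov` ends at tnow and is passed as future covariates
    preds = model.predict(n=6, series=target, future_covariates=past_cov)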
    

@yunakkano (Author)

Hi @dennisbader,

Thank you for your detailed response and for defining my problem more clearly, as well as explaining the use of the output_chunk_shift option.

To further clarify, my main issue is that the target variable y is unavailable even within the 6 past time steps. The timing when y becomes available is significantly later than when x1 at tnow is observed. For example, if the intervals are in hours, the value of y at tnow might only be available several weeks later. Here is the table to illustrate my scenario:

var\time   t0   t1   t2   t3   t4   tnow   tpred1   tpred...   tpred6
y           -    -    -    -    -    -      ?        ?          ?
x1          5    6    7    8    9    10     -        -          -

In this context, defining an exact number for output_chunk_shift is difficult because the value of y is not available in a consistent and predictable manner relative to x1, x2, and x3. The target values for y are delayed, making it challenging to set a specific output_chunk_shift that would account for this delay during model training.

@dennisbader (Collaborator)

Hi @yunakkano, in that case, the only option is to use our regression models without using any information of the target series y as input (with lags=None).

model = RegressionModel(
    lags=None,
    lags_past_covariates=6,
    output_chunk_length=6
)
model.fit(target, past_covariates=past_cov)

@yunakkano (Author)

Hi @dennisbader,

Thank you for the recommendation to use the regression models without any information from the target series y as input.

I have a somewhat naive question: would it be possible to customize the original TSMixer model to handle my scenario by, for example, fixing y to a specific value during training to prevent the model from learning from the target?

Specifically, do you think a customization where we replace all information of the target column with zero in the forward() method of the _TSMixerModule class would be effective?

Thank you again for your assistance.

@dennisbader (Collaborator)

Hi @yunakkano, yes, you can overwrite the forward pass to ignore the past target values in the input chunk.
You will have to subclass _TSMixerModule to overwrite the forward pass, and TSMixerModel to create a model from this custom module. Below you can find an example of how it would work:

Note that you do need the values of y for training, as the model needs to compute the loss between the predicted and actual y in the output chunk/window/forecast horizon.

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from darts.dataprocessing.transformers import Scaler
from darts.datasets import AirPassengersDataset
from darts.models.forecasting.pl_forecasting_module import io_processor
from darts.models.forecasting.tsmixer_model import TSMixerModel, _TSMixerModule


class _CustomTSMixerModule(_TSMixerModule):
    def __init__(self, **kwargs):
        # remember the number of target features for dropping the target input in `forward`
        self.input_target_dim = kwargs["input_dim"]
        # setup the model with zero target features
        kwargs["input_dim"] = 0
        super().__init__(**kwargs)

    @io_processor
    def forward(self, x_in) -> torch.Tensor:
        # x_past has (past target, past covariates, and historic part of future covariates)
        x_past, x_future, x_static = x_in
        # drop the past target components from the past features
        x_past = x_past[:, :, self.input_target_dim:]
        # call the undecorated parent forward; `__wrapped__` bypasses the
        # `io_processor` decorator so it is not applied twice
        return super().forward.__wrapped__(self, (x_past, x_future, x_static))


class CustomTSMixerModel(TSMixerModel):
    # we have to overwrite this method to return a `_CustomTSMixerModule`
    def _create_model(self, train_sample) -> nn.Module:
        # this is all copy-pasted
        (
            past_target,
            past_covariates,
            historic_future_covariates,
            future_covariates,
            static_covariates,
            future_target,
        ) = train_sample

        input_dim = past_target.shape[1]
        output_dim = future_target.shape[1]
        static_cov_dim = (
            static_covariates.shape[0] * static_covariates.shape[1]
            if static_covariates is not None
            else 0
        )
        future_cov_dim = (
            future_covariates.shape[1] if future_covariates is not None else 0
        )
        past_cov_dim = past_covariates.shape[1] if past_covariates is not None else 0
        nr_params = 1 if self.likelihood is None else self.likelihood.num_parameters

        # ---> now we return an instance of our custom module
        return _CustomTSMixerModule(
            input_dim=input_dim,
            output_dim=output_dim,
            future_cov_dim=future_cov_dim,
            past_cov_dim=past_cov_dim,
            static_cov_dim=static_cov_dim,
            nr_params=nr_params,
            hidden_size=self.hidden_size,
            ff_size=self.ff_size,
            num_blocks=self.num_blocks,
            activation=self.activation,
            dropout=self.dropout,
            norm_type=self.norm_type,
            normalize_before=self.normalize_before,
            **self.pl_module_params,
        )

Now we can use this new model just like the original TSMixerModel. In this example, I'll simply use y as past_covariates to check that it works. At prediction time, we replace all target values with np.nan, and we see that the forecast is still successful, since it only used the past covariates information.

target = AirPassengersDataset().load().astype(np.float32)
target_train = target[:-6]

# scale data
scaler = Scaler()
target_train = scaler.fit_transform(target_train)
# past covariates (just use the target series so we see if it works)
past_cov = target_train.copy()

# note that you cannot use `use_reversible_instance_norm` anymore, since at
# prediction time we no longer know the input target values
model = CustomTSMixerModel(
    input_chunk_length=12,
    output_chunk_length=6,
    random_state=42,
)
model.fit(target_train, past_covariates=past_cov, epochs=100)

# let's overwrite all values in the `target_train` series with NaNs and see if prediction works
vals_target = target_train.all_values(copy=False)
vals_target[:] = np.nan
preds = model.predict(n=6, series=target_train, past_covariates=past_cov)

target[-12:].plot()
scaler.inverse_transform(preds).plot()
plt.show()
[image: plot of the last 12 actual values together with the inverse-transformed forecast]

yunakkano (Author) commented May 16, 2024

Wow, @dennisbader! I never imagined you could come up with such an amazing idea.

Thank you so much for this detailed explanation and the code example. I'm really excited to try this out.

Just to confirm, based on your implementation here:

x_past = x_past[:, :, self.input_target_dim:]

the targets (in my case y) must be located at the beginning of the input tensor, and the features other than the targets come after them, correct?

Thanks again for your incredible support.

@dennisbader (Collaborator)

Hi @yunakkano, no worries.

Yes, they must be located at the beginning. This is handled automatically by our datasets, so you shouldn't have to worry about it.
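
For reference, a sketch of the layout of the feature (last) dimension of x_past, consistent with the comment in the custom forward above and the dimensions used in _create_model:

# x_past[:, :, :input_dim]                            -> past target components (y)
# x_past[:, :, input_dim : input_dim + past_cov_dim]  -> past covariates
# x_past[:, :, input_dim + past_cov_dim :]            -> historic part of future covariates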

tRosenflanz (Contributor) commented May 16, 2024

I wish this issue had been created 3 days ago, when I wanted to do exactly the same thing.
In my case I didn't want the target and the historic future covariates to be included, so I just overrode them with 0s:

from typing import Tuple
from darts.utils.data.inference_dataset import MixedCovariatesInferenceDataset
from darts.utils.data.sequential_dataset import MixedCovariatesSequentialDataset
from numpy import ndarray
# past_target,
# past_covariates,
# historic_future_covariates,
# future_covariates,
# static_covariates,
# future_target
### remove the past target and historic future covariates because they get concatenated into model inputs
class StrippedMixedCovariatesSequentialDataset(MixedCovariatesSequentialDataset):
    def __getitem__(
        self, idx
    ) -> Tuple[ndarray, ndarray | None, ndarray | None, ndarray | None, ndarray | None, ndarray]:
        vals = list(super().__getitem__(idx))
        vals[0] = vals[0] * 0  # zero out the past target
        vals[2] = vals[2] * 0  # zero out the historic future covariates
        return tuple(vals)


class StrippedMixedCovariatesInferenceDataset(MixedCovariatesInferenceDataset):
    def __getitem__(
        self, idx
    ) -> Tuple[ndarray, ndarray | None, ndarray | None, ndarray | None, ndarray | None, ndarray]:
        vals = list(super().__getitem__(idx))
        vals[0] = vals[0] * 0  # zero out the past target
        vals[2] = vals[2] * 0  # zero out the historic future covariates
        return tuple(vals)

and these can then be used with fit_from_dataset / predict_from_dataset accordingly.
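
A hypothetical usage sketch, assuming the standard MixedCovariatesSequentialDataset constructor arguments and reusing `target` / `past_cov` from the earlier examples:

train_ds = StrippedMixedCovariatesSequentialDataset(
    target_series=target,
    past_covariates=past_cov,
    input_chunk_length=12,
    output_chunk_length=6,
)
model.fit_from_dataset(train_ds)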

@dennisbader (Collaborator)

Hi @tRosenflanz, this is certainly another way to do it; just note that in this case the model complexity is larger, since it still processes the redundant (zeroed-out) inputs.

@SaltedfishLZX

This discussion is really helpful. I hope all models will get the option to ignore past target values.
