Issue with TFT.forward() method #1519

Open
Bruno-TT opened this issue Feb 20, 2024 · 0 comments

Bruno-TT commented Feb 20, 2024

Hi guys,

My TFT hasn't been training successfully, and I think I've found the reason why. Apologies if I've misunderstood anything; please feel free to point me to a fix or explain what I'm doing wrong.

In this call inside the TFT's .forward() method

embeddings_varying_decoder, decoder_sparse_weights = self.decoder_variable_selection(
    embeddings_varying_decoder,
    static_context_variable_selection[:, max_encoder_length:],
)

the variable selection network's forward() looks like this

def forward(self, x: Dict[str, torch.Tensor], context: torch.Tensor = None):
    if self.num_inputs > 1:
        # transform single variables
        var_outputs = []
        weight_inputs = []
        for name in self.input_sizes.keys():
            # select embedding belonging to a single input
            variable_embedding = x[name]
            if name in self.prescalers:
                variable_embedding = self.prescalers[name](variable_embedding)
            weight_inputs.append(variable_embedding)
            var_outputs.append(self.single_variable_grns[name](variable_embedding))
        var_outputs = torch.stack(var_outputs, dim=-1)

        # calculate variable weights
        flat_embedding = torch.cat(weight_inputs, dim=-1)
        sparse_weights = self.flattened_grn(flat_embedding, context)
        sparse_weights = self.softmax(sparse_weights).unsqueeze(-2)

        outputs = var_outputs * sparse_weights
        outputs = outputs.sum(dim=-1)
    else:  # for one input, do not perform variable selection but just encoding
        name = next(iter(self.single_variable_grns.keys()))
        variable_embedding = x[name]
        if name in self.prescalers:
            variable_embedding = self.prescalers[name](variable_embedding)
        outputs = self.single_variable_grns[name](variable_embedding)  # fast forward if only one variable
        if outputs.ndim == 3:  # -> batch size, time, hidden size, n_variables
            sparse_weights = torch.ones(outputs.size(0), outputs.size(1), 1, 1, device=outputs.device)
        else:  # ndim == 2 -> batch size, hidden size, n_variables
            sparse_weights = torch.ones(outputs.size(0), 1, 1, device=outputs.device)
    return outputs, sparse_weights

the line

name = next(iter(self.single_variable_grns.keys()))

raises StopIteration when self.num_inputs == 0, because self.single_variable_grns is empty and the iterator is exhausted immediately. The exception only gets caught 26 stack frames down 😆 in _TrainingEpochLoop.run() (on the self.advance(data_fetcher) line), where it terminates the training prematurely (I think?).
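For what it's worth, the failure mode is easy to reproduce outside the library: next() on an iterator over an empty dict raises StopIteration, and a training loop that treats StopIteration as an end-of-data signal will silently stop instead of surfacing an error. A minimal standalone sketch (the names here are stand-ins for illustration, not the library's actual code), plus the kind of guard I'd imagine as a fix:

```python
# Stand-in for an empty ModuleDict, i.e. the num_inputs == 0 case.
single_variable_grns = {}

def forward_else_branch(grns):
    # The problematic line from the report: on an empty dict,
    # next() raises StopIteration immediately.
    return next(iter(grns.keys()))

def training_loop():
    # Stand-in for a loop like _TrainingEpochLoop.run(): a StopIteration
    # escaping from a step is interpreted as "iteration finished".
    for step in range(3):
        try:
            forward_else_branch(single_variable_grns)
        except StopIteration:
            return f"stopped early at step {step}"
    return "completed all steps"

print(training_loop())  # -> stopped early at step 0

def forward_else_branch_guarded(grns):
    # One possible fix: fail loudly with a descriptive error instead of
    # letting StopIteration masquerade as normal loop termination.
    if not grns:
        raise ValueError(
            "variable selection received no input variables (num_inputs == 0)"
        )
    return next(iter(grns.keys()))
```

With the guard, the misconfiguration shows up as an explicit ValueError at the offending call site rather than a silently truncated training run.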

A fix would be hugely appreciated. Thanks a lot!
