My TFT hasn't been working and I think I've found the reason why. Apologies if I've misunderstood anything; please feel free to tell me a fix or explain what I'm doing wrong.

In the `.forward()` method, which looks like this:
```python
def forward(self, x: Dict[str, torch.Tensor], context: torch.Tensor = None):
    if self.num_inputs > 1:
        # transform single variables
        var_outputs = []
        weight_inputs = []
        for name in self.input_sizes.keys():
            # select embedding belonging to a single input
            variable_embedding = x[name]
            if name in self.prescalers:
                variable_embedding = self.prescalers[name](variable_embedding)
            weight_inputs.append(variable_embedding)
            var_outputs.append(self.single_variable_grns[name](variable_embedding))
        var_outputs = torch.stack(var_outputs, dim=-1)

        # calculate variable weights
        flat_embedding = torch.cat(weight_inputs, dim=-1)
        sparse_weights = self.flattened_grn(flat_embedding, context)
        sparse_weights = self.softmax(sparse_weights).unsqueeze(-2)

        outputs = var_outputs * sparse_weights
        outputs = outputs.sum(dim=-1)
    else:  # for one input, do not perform variable selection but just encoding
        name = next(iter(self.single_variable_grns.keys()))
        variable_embedding = x[name]
        if name in self.prescalers:
            variable_embedding = self.prescalers[name](variable_embedding)
        outputs = self.single_variable_grns[name](variable_embedding)  # fast forward if only one variable
        if outputs.ndim == 3:  # -> batch size, time, hidden size, n_variables
            sparse_weights = torch.ones(outputs.size(0), outputs.size(1), 1, 1, device=outputs.device)
        else:  # ndim == 2 -> batch size, hidden size, n_variables
            sparse_weights = torch.ones(outputs.size(0), 1, 1, device=outputs.device)
    return outputs, sparse_weights
```
the line

```python
name = next(iter(self.single_variable_grns.keys()))
```

raises a `StopIteration` when `self.num_inputs == 0`: the `else` branch is still taken, but `self.single_variable_grns` is empty, so `next(iter(...))` has nothing to yield. The exception gets caught 26 stack frames down 😆 in `_TrainingEpochLoop.run()` [we are on the `self.advance(data_fetcher)` line], which presumably treats it as the end of the epoch and terminates the training prematurely (I think?).
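For reference, the failure reduces to this standalone sketch (a plain dict stands in for `self.single_variable_grns`):

```python
# Sketch of the failing branch: when the mapping of per-variable GRNs
# is empty (num_inputs == 0), next(iter(...)) has nothing to yield
# and raises StopIteration, which then escapes forward().
single_variable_grns = {}  # stands in for an empty self.single_variable_grns

try:
    name = next(iter(single_variable_grns.keys()))
except StopIteration:
    print("StopIteration escapes the else branch")
```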
A fix would be hugely appreciated. Thanks a lot!
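For what it's worth, one possible guard (just a sketch; the helper name `first_variable_name` and the error message are mine, and the maintainers may prefer different behavior) would turn the silent `StopIteration` into a loud, descriptive error:

```python
def first_variable_name(single_variable_grns):
    # Hypothetical helper replacing the bare next(iter(...)) call.
    # StopIteration must not escape forward(): PyTorch Lightning's
    # training loop catches it and treats it as the end of the epoch,
    # which is why training stops silently instead of crashing.
    try:
        return next(iter(single_variable_grns))
    except StopIteration:
        raise ValueError(
            "variable selection received no input variables "
            "(num_inputs == 0); check the dataset configuration"
        ) from None
```

Calling it with an empty mapping would then fail immediately with a readable message instead of ending training mid-epoch.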