Skorch forwarding data columns as kwargs when using gridsearchcv #1043

Open
kotoroshinoto opened this issue Feb 5, 2024 · 4 comments

@kotoroshinoto

Working with the data in this link: https://www.kaggle.com/datasets/uciml/breast-cancer-wisconsin-data

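from collections import OrderedDict

import torch
from torch import nn
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import RobustScaler
from sklearn.utils.estimator_checks import check_estimator
from skorch import NeuralNetBinaryClassifier
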
class MyNeuralNetwork(nn.Module):
    def __init__(self, activation_func=nn.ReLU, hidden_layers=(512,1024,512,256)):
        super().__init__()
        n_inputs: int = 30
        n_outputs: int = 1
        layers = OrderedDict()
        topology = [n_inputs] + list(hidden_layers) + [n_outputs]
        activ_layers_added = 0
        linear_layers_added = 0
        for i in range(len(topology)-1):
            if i > 0:
                layers[f"{activation_func.__name__}_{activ_layers_added}"] = activation_func()
                activ_layers_added += 1
            in_size = topology[i]
            out_size = topology[i+1]
            layers[f"{nn.Linear.__name__}_{linear_layers_added}"] = nn.Linear(in_size, out_size)
            linear_layers_added += 1
        # layers['softmax'] = nn.Softmax(dim=-1)
        self.linear_relu_stack = nn.Sequential(layers)

    def forward(self, x):
        logits = self.linear_relu_stack(x.to(torch.float32))
        return logits
mynn = MyNeuralNetwork()
print(mynn)
sknet = NeuralNetBinaryClassifier(
    MyNeuralNetwork,
    module__hidden_layers=(512, 1024, 512, 256),
    max_epochs=100,
    lr=0.1,
    # Shuffle training data on each epoch
    iterator_train__shuffle=True,
)
pipeline = Pipeline([
    ('robust_scaler', RobustScaler()),
    ("cls", sknet)
])
pipeline.get_params()

I am able to run this:

from sklearn.metrics import classification_report

pipeline.fit(X_train, y_train)
test_preds = pipeline.predict(X_test)
print(classification_report(y_test, test_preds))

When I try to run this I get an error:

from sklearn.model_selection import GridSearchCV, StratifiedShuffleSplit

sknet.set_params(train_split=False, verbose=0)
params = {
    'lr': [0.01, 0.02],
    'max_epochs': [10, 20],
    'module__hidden_layers': [
        (512,1024,512,256),
        (512,512,512,512),
        (2048,1024,512,256,128,64)
    ],
}
gs = GridSearchCV(
    sknet,
    params,
    scoring='f1',
    error_score='raise',
    refit=False,
    cv=StratifiedShuffleSplit(n_splits=8, test_size=0.15),
    verbose=2)

gs.fit(X_train, y_train)
print("best score: {:.3f}, best params: {}".format(gs.best_score_, gs.best_params_))

TypeError: MyNeuralNetwork.forward() got an unexpected keyword argument 'radius_mean'

I expected the data to be passed to forward() as x, not unpacked into one keyword argument per column name. (X and y are pandas DataFrames or Series.)

@kotoroshinoto
Author

Here is the error stack trace:
https://pastebin.com/3ch44t75

@kotoroshinoto
Author

kotoroshinoto commented Feb 5, 2024

The data turns into a mapping when it runs the command at skorch/net.py, line 1182:

dataset_train, dataset_valid = sknet.get_split_datasets(X_train, y_train_ints)

Later on, it merges this mapping with fit_params and feeds the result to the forward function as keyword arguments, starting with the conditional on line 1518:

1518 if isinstance(x, Mapping):
1519     x_dict = self._merge_x_and_fit_params(x, fit_params)
1520     return self.module_(**x_dict)
1521 return self.module_(x, **fit_params)

Since X is a mapping because of the earlier conversion, skorch merges it with fit_params and calls the module with **x_dict instead of taking the branch that passes x positionally.
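To make the branch quoted above concrete, here is a minimal sketch in plain Python with no skorch or PyTorch involved. `ToyModule` and `infer` are hypothetical stand-ins, and the merge is simplified compared with whatever `_merge_x_and_fit_params` actually does; the point is only that a mapping input gets unpacked into keyword arguments.

```python
from collections.abc import Mapping


class ToyModule:
    """Stand-in for an nn.Module whose forward takes one positional input."""

    def forward(self, x):
        return x

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


def infer(module, x, **fit_params):
    """Mimic the quoted branch: mappings are unpacked into keyword arguments."""
    if isinstance(x, Mapping):
        # Simplified merge (assumption); the real helper may perform extra checks.
        x_dict = {**x, **fit_params}
        return module(**x_dict)
    return module(x, **fit_params)


module = ToyModule()

# A list (or array) goes down the positional branch and reaches forward as x.
positional_result = infer(module, [[17.99, 10.38]])

# A dict keyed by column name is unpacked, so forward is called with
# radius_mean=... and texture_mean=... instead of x.
try:
    infer(module, {"radius_mean": [17.99], "texture_mean": [10.38]})
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```

The message printed at the end has the same shape as the TypeError reported in the issue, with the first column name appearing as the unexpected keyword argument.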

@kotoroshinoto
Author

kotoroshinoto commented Feb 5, 2024

Using to_numpy on the pandas objects and a different scoring method name doesn't trigger this, which is very strange.

@BenjaminBossan
Collaborator

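Indeed, when you pass a pandas DataFrame as input to skorch, it will convert it to a dict, with each column becoming one entry in the dict, keyed by the column name. This is because PyTorch cannot deal with DataFrames, so we need to convert them to something more suitable. Those keys are then passed to the module's forward method as keyword arguments, which is why the error mentions 'radius_mean'.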
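As a rough illustration of that conversion, the sketch below builds a dict of per-column arrays from a small DataFrame. It is an approximation written for this thread, not a copy of skorch's internals, and the values are made up; only the column names come from the dataset.

```python
import pandas as pd

# Toy frame with made-up values; the column names mirror two from the dataset.
df = pd.DataFrame({
    "radius_mean": [12.0, 14.5, 20.1],
    "texture_mean": [18.2, 15.7, 22.4],
})

# One entry per column, keyed by column name. Each value is a NumPy array
# with one row per sample (the exact shape skorch uses is an assumption here).
as_dict = {name: df[name].to_numpy().reshape(-1, 1) for name in df.columns}

for name, values in as_dict.items():
    print(name, values.shape)
```

A module that wants to consume such a dict directly would need a forward signature with one parameter per column name, which is rarely practical for 30 numeric features.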

> Using to_numpy on the pandas objects and a different scoring method name doesn't trigger this, which is very strange.

When you pass a numpy array instead of a df, we don't encounter the aforementioned problem, which is why it works. Note, however, that this may not be what you want. For instance, if the df contains categorical data, you surely don't want to treat it like just numerical data.
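This is also a likely reason why the earlier pipeline.fit call succeeded while the grid search on the bare net failed: with scikit-learn's default output configuration, a scaler returns a NumPy array even when it is given a DataFrame, so the net behind it never sees column names. The sketch below checks that with toy values, and also shows how a parameter grid could be prefixed so that GridSearchCV can be given the whole pipeline. The step name "cls" is taken from the pipeline in the issue; everything else is an assumption for illustration.

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import RobustScaler

# Toy frame with made-up values; the column names mirror two from the dataset.
df = pd.DataFrame({
    "radius_mean": [12.0, 14.5, 20.1, 11.3],
    "texture_mean": [18.2, 15.7, 22.4, 19.9],
})

# With default settings the scaler hands back a plain ndarray, not a DataFrame.
scaled = RobustScaler().fit_transform(df)
scaled_is_ndarray = isinstance(scaled, np.ndarray)
print(type(scaled).__name__, scaled.shape)

# Hypothetical grid for searching over the pipeline instead of the bare net:
# every key is prefixed with the step name "cls" followed by two underscores.
net_params = {"lr": [0.01, 0.02], "max_epochs": [10, 20]}
pipeline_params = {f"cls__{name}": values for name, values in net_params.items()}
print(pipeline_params)
```

With a grid like this, GridSearchCV(pipeline, pipeline_params, ...) would run the scaler on every fold before the net receives the data.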

We have a helper class that takes care of some of this: DataFrameTransformer. Maybe this is something that would suit your needs. Otherwise, there is no easy solution to your issue: you need to do some feature engineering/transformation/scaling to make the data suitable for use with a neural net, then package the data either as a numpy array (if it's homogeneous) or as a dict of arrays/tensors.
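For the homogeneous case, a minimal sketch of the packaging step could look like the following. `frame_to_float32` is a hypothetical helper written for this thread, not part of skorch; it refuses non-numeric columns rather than casting them, but it cannot tell whether an integer column is really categorical, so that judgment stays with the caller.

```python
import numpy as np
import pandas as pd


def frame_to_float32(X_df: pd.DataFrame) -> np.ndarray:
    """Hypothetical helper: turn an all-numeric DataFrame into a float32 array."""
    non_numeric = X_df.select_dtypes(exclude="number").columns.tolist()
    if non_numeric:
        # Text or categorical columns need explicit encoding first.
        raise ValueError(f"encode these columns first: {non_numeric}")
    # Integer columns are cast as plain numbers; check that this is intended.
    return X_df.to_numpy(dtype=np.float32)


# Toy frame with made-up values; the column names mirror two from the dataset.
numeric_df = pd.DataFrame({
    "radius_mean": [12.0, 14.5, 20.1],
    "texture_mean": [18.2, 15.7, 22.4],
})
X_array = frame_to_float32(numeric_df)
print(X_array.dtype, X_array.shape)
```

The resulting array can then be passed to fit in place of the DataFrame, which matches the to_numpy workaround reported earlier in the thread.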


2 participants